diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 05fc14bbc72..ec1e0b6352a 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -391,6 +391,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF) define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF) + define_option(ARROW_WITH_UCX + "Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)" + OFF) + define_option(ARROW_WITH_UTF8PROC "Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)" ON) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index e9e826097b3..b6f1e2481fa 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -747,6 +747,10 @@ endif() if(ARROW_FLIGHT) add_subdirectory(flight) + + if(ARROW_WITH_UCX) + add_subdirectory(flight/transport/ucx) + endif() endif() if(ARROW_FLIGHT_SQL) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 7447e675e08..f9d135654b4 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -313,4 +313,14 @@ if(ARROW_BUILD_BENCHMARKS) add_dependencies(arrow-flight-benchmark arrow-flight-perf-server) add_dependencies(arrow_flight arrow-flight-benchmark) + + if(ARROW_WITH_UCX) + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static) + target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static) + else() + target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared) + target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared) + endif() + endif() endif(ARROW_BUILD_BENCHMARKS) diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index 872c67c80b7..fa0cc9a3d53 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -40,12 +40,20 @@ #include "arrow/flight/test_util.h" #ifdef ARROW_CUDA +#include #include "arrow/gpu/cuda_api.h" #endif +#ifdef ARROW_WITH_UCX +#include "arrow/flight/transport/ucx/ucx.h" +#endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); DEFINE_string(transport, "grpc", - "The network transport to use. Supported: \"grpc\" (default)."); + "The network transport to use. Supported: \"grpc\" (default)" +#ifdef ARROW_WITH_UCX + ", \"ucx\"" +#endif // ARROW_WITH_UCX + "."); DEFINE_string(server_host, "", "An existing performance server to benchmark against (leave blank to spawn " "one automatically)"); @@ -497,6 +505,21 @@ int main(int argc, char** argv) { options.disable_server_verification = true; } } + } else if (FLAGS_transport == "ucx") { +#ifdef ARROW_WITH_UCX + arrow::flight::transport::ucx::InitializeFlightUcx(); + if (FLAGS_test_unix || !FLAGS_server_unix.empty()) { + std::cerr << "Transport does not support domain sockets: " << FLAGS_transport + << std::endl; + return EXIT_FAILURE; + } + ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + + std::to_string(FLAGS_server_port)) + .Value(&location)); +#else + std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; +#endif } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; @@ -514,6 +537,16 @@ int main(int argc, char** argv) { ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager)); ABORT_NOT_OK(manager->GetDevice(0).Value(&device)); call_options.memory_manager = device->default_memory_manager(); + + // Needed to prevent UCX warning + // cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid + // device context + std::shared_ptr context; + ABORT_NOT_OK(device->GetContext().Value(&context)); + auto cuda_status = cuCtxPushCurrent(reinterpret_cast(context->handle())); + if (cuda_status != CUDA_SUCCESS) { + ARROW_LOG(WARNING) << "CUDA error " << cuda_status; + } #else std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl; return 1; diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index cc42ffedd68..37e3ec4d771 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -43,10 +44,17 @@ #ifdef ARROW_CUDA #include "arrow/gpu/cuda_api.h" #endif +#ifdef ARROW_WITH_UCX +#include "arrow/flight/transport/ucx/ucx.h" +#endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); DEFINE_string(transport, "grpc", - "The network transport to use. Supported: \"grpc\" (default)."); + "The network transport to use. Supported: \"grpc\" (default)" +#ifdef ARROW_WITH_UCX + ", \"ucx\"" +#endif // ARROW_WITH_UCX + "."); DEFINE_string(server_host, "localhost", "Host where the server is running on"); DEFINE_int32(port, 31337, "Server port to listen on"); DEFINE_string(server_unix, "", "Unix socket path where the server is running on"); @@ -97,7 +105,7 @@ class PerfDataStream : public FlightDataStream { if (records_sent_ >= total_records_) { // Signal that iteration is over payload.ipc_message.metadata = nullptr; - return Status::OK(); + return payload; } if (verify_) { @@ -274,6 +282,29 @@ int main(int argc, char** argv) { ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix) .Value(&connect_location)); } + } else if (FLAGS_transport == "ucx") { +#ifdef ARROW_WITH_UCX + arrow::flight::transport::ucx::InitializeFlightUcx(); + if (FLAGS_server_unix.empty()) { + if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { + std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; + } + ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + + std::to_string(FLAGS_port)) + .Value(&bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + + std::to_string(FLAGS_port)) + .Value(&connect_location)); + } else { + std::cerr << "Transport does not support domain sockets: " << FLAGS_transport + << std::endl; + return EXIT_FAILURE; + } +#else + std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; +#endif } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; @@ -308,6 +339,7 @@ int main(int argc, char** argv) { // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); std::cout << "Server transport: " << FLAGS_transport << std::endl; + std::cout << "Server location: " << connect_location.ToString() << std::endl; if (FLAGS_server_unix.empty()) { std::cout << "Server host: " << FLAGS_server_host << std::endl; std::cout << "Server port: " << FLAGS_port << std::endl; diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 2cfac641446..1ec06a1f004 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -45,7 +45,7 @@ using arrow::internal::checked_cast; void ConnectivityTest::TestGetPort() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); ASSERT_GT(server->port(), 0); @@ -53,7 +53,7 @@ void ConnectivityTest::TestGetPort() { void ConnectivityTest::TestBuilderHook() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); bool builder_hook_run = false; options.builder_hook = [&builder_hook_run](void* builder) { @@ -68,7 +68,7 @@ void ConnectivityTest::TestBuilderHook() { void ConnectivityTest::TestShutdown() { // Regression test for ARROW-15181 constexpr int kIterations = 10; - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); for (int i = 0; i < kIterations; i++) { std::unique_ptr server = ExampleTestServer(); @@ -84,7 +84,7 @@ void ConnectivityTest::TestShutdown() { void ConnectivityTest::TestShutdownWithDeadline() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); ASSERT_GT(server->port(), 0); @@ -96,13 +96,13 @@ void ConnectivityTest::TestShutdownWithDeadline() { } void ConnectivityTest::TestBrokenConnection() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); std::unique_ptr client; ASSERT_OK_AND_ASSIGN(location, - Location::ForScheme(transport(), "localhost", server->port())); + Location::ForScheme(transport(), "127.0.0.1", server->port())); ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location)); ASSERT_OK(server->Shutdown()); @@ -117,7 +117,7 @@ void ConnectivityTest::TestBrokenConnection() { void DataTest::SetUp() { server_ = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server_->Init(options)); @@ -129,7 +129,7 @@ void DataTest::TearDown() { } Status DataTest::ConnectClient() { ARROW_ASSIGN_OR_RAISE(auto location, - Location::ForScheme(transport(), "localhost", server_->port())); + Location::ForScheme(transport(), "127.0.0.1", server_->port())); ARROW_ASSIGN_OR_RAISE(client_, FlightClient::Connect(location)); return Status::OK(); } @@ -638,7 +638,7 @@ class DoPutTestServer : public FlightServerBase { }; void DoPutTest::SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, @@ -766,7 +766,7 @@ void DoPutTest::TestLargeBatch() { void DoPutTest::TestSizeLimit() { const int64_t size_limit = 4096; ASSERT_OK_AND_ASSIGN(auto location, - Location::ForScheme(transport(), "localhost", server_->port())); + Location::ForScheme(transport(), "127.0.0.1", server_->port())); auto client_options = FlightClientOptions::Defaults(); client_options.write_size_limit_bytes = size_limit; ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location, client_options)); @@ -866,7 +866,7 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context, } void AppMetadataTest::SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, @@ -1045,7 +1045,7 @@ class IpcOptionsTestServer : public FlightServerBase { }; void IpcOptionsTest::SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, @@ -1241,7 +1241,7 @@ void CudaDataTest::SetUp() { impl_->device = std::move(device); impl_->context = std::move(context); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [this](FlightServerOptions* options) { diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index 5320b958d58..d5b774b4a37 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -113,7 +113,7 @@ Status MakeServer(const Location& location, std::unique_ptr* s RETURN_NOT_OK(make_server_options(&server_options)); RETURN_NOT_OK((*server)->Init(server_options)); std::string uri = - location.scheme() + "://localhost:" + std::to_string((*server)->port()); + location.scheme() + "://127.0.0.1:" + std::to_string((*server)->port()); ARROW_ASSIGN_OR_RAISE(auto real_location, Location::Parse(uri)); FlightClientOptions client_options = FlightClientOptions::Defaults(); RETURN_NOT_OK(make_client_options(&client_options)); diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt new file mode 100644 index 00000000000..6e315b68d6c --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(arrow_flight_transport_ucx) +arrow_install_all_headers("arrow/flight/transport/ucx") + +find_package(PkgConfig REQUIRED) +pkg_check_modules(UCX REQUIRED IMPORTED_TARGET ucx) + +set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS + ucx_client.cc + ucx_server.cc + ucx.cc + ucx_internal.cc + util_internal.cc) +set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS) + +include_directories(SYSTEM ${UCX_INCLUDE_DIRS}) +list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS PkgConfig::UCX) + +add_arrow_lib(arrow_flight_transport_ucx + # CMAKE_PACKAGE_NAME + # ArrowFlightTransportUcx + # PKG_CONFIG_NAME + # arrow-flight-transport-ucx + SOURCES + ${ARROW_FLIGHT_TRANSPORT_UCX_SRCS} + PRECOMPILED_HEADERS + "$<$:arrow/flight/transport/ucx/pch.h>" + DEPENDENCIES + SHARED_LINK_FLAGS + ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt + SHARED_LINK_LIBS + arrow_shared + arrow_flight_shared + ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS} + STATIC_LINK_LIBS + arrow_static + arrow_flight_static + ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS}) + +if(ARROW_BUILD_TESTS) + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS + arrow_static + arrow_flight_static + arrow_flight_testing_static + arrow_flight_transport_ucx_static + ${ARROW_TEST_LINK_LIBS}) + else() + set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS + arrow_shared + arrow_flight_shared + arrow_flight_testing_shared + arrow_flight_transport_ucx_shared + ${ARROW_TEST_LINK_LIBS}) + endif() + add_arrow_test(flight_transport_ucx_test + STATIC_LINK_LIBS + ${ARROW_FLIGHT_UCX_TEST_LINK_LIBS} + LABELS + "arrow_flight") +endif() diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc new file mode 100644 index 00000000000..6a580af92fd --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc @@ -0,0 +1,386 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/flight/test_definitions.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/transport/ucx/ucx.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/config.h" + +#ifdef UCP_API_VERSION +#error "UCX headers should not be in public API" +#endif + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#ifdef ARROW_CUDA +#include "arrow/gpu/cuda_api.h" +#endif + +namespace arrow { +namespace flight { + +class UcxEnvironment : public ::testing::Environment { + public: + void SetUp() override { transport::ucx::InitializeFlightUcx(); } +}; + +testing::Environment* const kUcxEnvironment = + testing::AddGlobalTestEnvironment(new UcxEnvironment()); + +//------------------------------------------------------------ +// Common transport tests + +class UcxConnectivityTest : public ConnectivityTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_CONNECTIVITY(UcxConnectivityTest); + +class UcxDataTest : public DataTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_DATA(UcxDataTest); + +class UcxDoPutTest : public DoPutTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_DO_PUT(UcxDoPutTest); + +class UcxAppMetadataTest : public AppMetadataTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_APP_METADATA(UcxAppMetadataTest); + +class UcxIpcOptionsTest : public IpcOptionsTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_IPC_OPTIONS(UcxIpcOptionsTest); + +class UcxCudaDataTest : public CudaDataTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest); + +//------------------------------------------------------------ +// UCX internals tests + +constexpr std::initializer_list kStatusCodes = { + StatusCode::OK, + StatusCode::OutOfMemory, + StatusCode::KeyError, + StatusCode::TypeError, + StatusCode::Invalid, + StatusCode::IOError, + StatusCode::CapacityError, + StatusCode::IndexError, + StatusCode::Cancelled, + StatusCode::UnknownError, + StatusCode::NotImplemented, + StatusCode::SerializationError, + StatusCode::RError, + StatusCode::CodeGenError, + StatusCode::ExpressionValidationError, + StatusCode::ExecutionError, + StatusCode::AlreadyExists, +}; + +constexpr std::initializer_list kFlightStatusCodes = { + FlightStatusCode::Internal, FlightStatusCode::TimedOut, + FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated, + FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable, + FlightStatusCode::Failed, +}; + +class TestStatusDetail : public StatusDetail { + public: + const char* type_id() const override { return "test-status-detail"; } + std::string ToString() const override { return "Custom status detail"; } +}; + +namespace transport { +namespace ucx { + +static constexpr std::initializer_list kFrameTypes = { + FrameType::kHeaders, FrameType::kBuffer, FrameType::kPayloadHeader, + FrameType::kPayloadBody, FrameType::kDisconnect, +}; + +TEST(FrameHeader, Basics) { + for (const auto frame_type : kFrameTypes) { + FrameHeader header; + ASSERT_OK(header.Set(frame_type, /*counter=*/42, /*body_size=*/65535)); + if (frame_type == FrameType::kDisconnect) { + ASSERT_RAISES(Cancelled, Frame::ParseHeader(header.data(), header.size())); + } else { + ASSERT_OK_AND_ASSIGN(auto frame, Frame::ParseHeader(header.data(), header.size())); + ASSERT_EQ(frame->type, frame_type); + ASSERT_EQ(frame->counter, 42); + ASSERT_EQ(frame->size, 65535); + } + } +} + +TEST(FrameHeader, FrameType) { + for (const auto frame_type : kFrameTypes) { + ASSERT_LE(static_cast(frame_type), static_cast(FrameType::kMaxFrameType)); + } +} + +TEST(HeadersFrame, Parse) { + const char* data = + ("\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x03x-foobar" + "\x00\x00\x00\x05\x00\x00\x00\x01x-bin\x01"); + constexpr int64_t size = 34; + + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), size)); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Parse(std::move(buffer))); + ASSERT_OK_AND_ASSIGN(auto foo, headers.Get("x-foo")); + ASSERT_EQ(foo, "bar"); + ASSERT_OK_AND_ASSIGN(auto bin, headers.Get("x-bin")); + ASSERT_EQ(bin, "\x01"); + } + { + std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 3)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("expected number of headers"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 7)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("expected length of key 1"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), 10)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("expected length of value 1"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), 12)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("expected key 1 to have length 5, but only 0 bytes remain"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), 17)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "expected value 1 to have length 3, but only 0 bytes remain"), + HeadersFrame::Parse(std::move(buffer))); + } +} + +TEST(HeadersFrame, RoundTripStatus) { + for (const auto code : kStatusCodes) { + { + Status expected = code == StatusCode::OK ? Status() : Status(code, "foo"); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); + Status status; + ASSERT_OK(headers.GetStatus(&status)); + ASSERT_EQ(status, expected); + } + + if (code == StatusCode::OK) continue; + + // Attach a generic status detail + { + auto detail = std::make_shared(); + Status original(code, "foo", detail); + Status expected(code, "foo", + std::make_shared(FlightStatusCode::Internal, + detail->ToString())); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); + Status status; + ASSERT_OK(headers.GetStatus(&status)); + ASSERT_EQ(status, expected); + } + + // Attach a Flight status detail + for (const auto flight_code : kFlightStatusCodes) { + Status expected(code, "foo", + std::make_shared(flight_code, "extra")); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); + Status status; + ASSERT_OK(headers.GetStatus(&status)); + ASSERT_EQ(status, expected); + } + } +} +} // namespace ucx +} // namespace transport + +//------------------------------------------------------------ +// Ad-hoc UCX-specific tests + +class SimpleTestServer : public FlightServerBase { + public: + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* info) override { + if (request.path.size() > 0 && request.path[0] == "error") { + return status_; + } + auto examples = ExampleFlightInfo(); + info->reset(new FlightInfo(examples[0])); + return Status::OK(); + } + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + auto batch_reader = std::make_shared(batches[0]->schema(), batches); + *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); + return Status::OK(); + } + + void set_error_status(Status st) { status_ = std::move(st); } + + private: + Status status_; +}; + +class TestUcx : public ::testing::Test { + public: + void SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "127.0.0.1", 0)); + ASSERT_OK(MakeServer( + location, &server_, &client_, + [](FlightServerOptions* options) { return Status::OK(); }, + [](FlightClientOptions* options) { return Status::OK(); })); + } + + void TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); + } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +TEST_F(TestUcx, GetFlightInfo) { + auto descriptor = FlightDescriptor::Path({"foo", "bar"}); + std::unique_ptr info; + ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor)); + // Test that we can reuse the connection + ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor)); +} + +TEST_F(TestUcx, SequentialClients) { + ASSERT_OK_AND_ASSIGN( + auto client2, + FlightClient::Connect(server_->location(), FlightClientOptions::Defaults())); + + Ticket ticket{"a"}; + + ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket)); + ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); + + ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket)); + ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); + + AssertTablesEqual(*table1, *table2); +} + +TEST_F(TestUcx, ConcurrentClients) { + ASSERT_OK_AND_ASSIGN( + auto client2, + FlightClient::Connect(server_->location(), FlightClientOptions::Defaults())); + + Ticket ticket{"a"}; + + ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket)); + ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket)); + + ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); + ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); + + AssertTablesEqual(*table1, *table2); +} + +TEST_F(TestUcx, Errors) { + auto descriptor = FlightDescriptor::Path({"error", "bar"}); + auto* server = reinterpret_cast(server_.get()); + for (const auto code : kStatusCodes) { + if (code == StatusCode::OK) continue; + + Status expected(code, "Error message"); + server->set_error_status(expected); + Status actual = client_->GetFlightInfo(descriptor).status(); + ASSERT_EQ(actual, expected); + + // Attach a generic status detail + { + auto detail = std::make_shared(); + server->set_error_status(Status(code, "foo", detail)); + Status expected(code, "foo", + std::make_shared(FlightStatusCode::Internal, + detail->ToString())); + Status actual = client_->GetFlightInfo(descriptor).status(); + ASSERT_EQ(actual, expected); + } + + // Attach a Flight status detail + for (const auto flight_code : kFlightStatusCodes) { + Status expected(code, "Error message", + std::make_shared(flight_code, "extra")); + server->set_error_status(expected); + Status actual = client_->GetFlightInfo(descriptor).status(); + ASSERT_EQ(actual, expected); + } + } +} + +TEST(TestUcxIpV6, DISABLED_IpV6Port) { + // Also, disabled in CI as machines lack an IPv6 interface + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0)); + + std::unique_ptr server(new SimpleTestServer()); + FlightServerOptions server_options(location); + ASSERT_OK(server->Init(server_options)); + + FlightClientOptions client_options = FlightClientOptions::Defaults(); + ASSERT_OK_AND_ASSIGN(auto client, + FlightClient::Connect(server->location(), client_options)); + + auto descriptor = FlightDescriptor::Path({"foo", "bar"}); + ASSERT_OK_AND_ASSIGN(auto info, client->GetFlightInfo(descriptor)); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc new file mode 100644 index 00000000000..0e3daf60213 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx.cc @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/ucx.h" + +#include + +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/ucx_internal.h" +#include "arrow/flight/transport_server.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +namespace { +std::once_flag kInitializeOnce; +} +void InitializeFlightUcx() { + std::call_once(kInitializeOnce, []() { + auto* registry = flight::internal::GetDefaultTransportRegistry(); + DCHECK_OK(registry->RegisterClient("ucx", MakeUcxClientImpl)); + DCHECK_OK(registry->RegisterServer("ucx", MakeUcxServerImpl)); + }); +} +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.h b/cpp/src/arrow/flight/transport/ucx/ucx.h new file mode 100644 index 00000000000..dda2c83035c --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx.h @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Experimental UCX-based transport for Flight. + +#pragma once + +#include "arrow/flight/visibility.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +ARROW_FLIGHT_EXPORT +void InitializeFlightUcx(); + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc new file mode 100644 index 00000000000..173132062e5 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -0,0 +1,733 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// The client-side implementation of a UCX-based transport for +/// Flight. +/// +/// Each UCX driver is used to support one call at a time. This gives +/// the greatest throughput for data plane methods, but is relatively +/// expensive in terms of other resources, both for the server and the +/// client. (UCX drivers have multiple threading modes: single-thread +/// access, serialized access, and multi-thread access. Testing found +/// that multi-thread access incurred high synchronization costs.) +/// Hence, for concurrent calls in a single client, we must maintain +/// multiple drivers, and so unlike gRPC, there is no real difference +/// between using one client concurrently and using multiple +/// independent clients. + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#include +#include +#include +#include + +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/client.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +namespace { +class UcxClientImpl; + +Status MergeStatuses(Status server_status, Status transport_status) { + if (server_status.ok()) { + if (transport_status.ok()) return server_status; + return transport_status; + } else if (transport_status.ok()) { + return server_status; + } + return Status::FromDetailAndArgs(server_status.code(), server_status.detail(), + server_status.message(), + ". Transport context: ", transport_status.ToString()); +} + +/// \brief An individual connection to the server. +class ClientConnection { + public: + ClientConnection() = default; + ARROW_DISALLOW_COPY_AND_ASSIGN(ClientConnection); + ARROW_DEFAULT_MOVE_AND_ASSIGN(ClientConnection); + ~ClientConnection() { DCHECK(!driver_) << "Connection was not closed!"; } + + Status Init(std::shared_ptr ucp_context, const arrow::internal::Uri& uri) { + auto status = InitImpl(std::move(ucp_context), uri); + // Clean up after-the-fact if we fail to initialize + if (!status.ok()) { + if (driver_) { + status = MergeStatuses(std::move(status), driver_->Close()); + driver_.reset(); + remote_endpoint_ = nullptr; + } + if (ucp_worker_) ucp_worker_.reset(); + } + return status; + } + + Status InitImpl(std::shared_ptr ucp_context, + const arrow::internal::Uri& uri) { + { + ucs_status_t status; + ucp_worker_params_t worker_params; + std::memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED; + + ucp_worker_h ucp_worker; + status = ucp_worker_create(ucp_context->get(), &worker_params, &ucp_worker); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); + ucp_worker_.reset(new UcpWorker(std::move(ucp_context), ucp_worker)); + } + { + // Create endpoint for remote worker + struct sockaddr_storage connect_addr; + ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr)); + std::string peer; + ARROW_UNUSED(SockaddrToString(connect_addr).Value(&peer)); + ARROW_LOG(DEBUG) << "Connecting to " << peer; + + ucp_ep_params_t params; + params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME | + UCP_EP_PARAM_FIELD_SOCK_ADDR; + params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER; + params.name = "UcxClientImpl"; + params.sockaddr.addr = reinterpret_cast(&connect_addr); + params.sockaddr.addrlen = addrlen; + + auto status = ucp_ep_create(ucp_worker_->get(), ¶ms, &remote_endpoint_); + RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status)); + } + + driver_.reset(new UcpCallDriver(ucp_worker_, remote_endpoint_)); + ARROW_LOG(DEBUG) << "Connected to " << driver_->peer(); + + { + // Set up Active Message (AM) handler + ucp_am_handler_param_t handler_params; + handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | + UCP_AM_HANDLER_PARAM_FIELD_CB | + UCP_AM_HANDLER_PARAM_FIELD_ARG; + handler_params.id = kUcpAmHandlerId; + handler_params.cb = HandleIncomingActiveMessage; + handler_params.arg = driver_.get(); + ucs_status_t status = + ucp_worker_set_am_recv_handler(ucp_worker_->get(), &handler_params); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status)); + } + + return Status::OK(); + } + + Status Close() { + if (!driver_) return Status::OK(); + + auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0); + const auto ucs_status = FlightUcxStatusDetail::Unwrap(status); + if (IsIgnorableDisconnectError(ucs_status)) { + status = Status::OK(); + } + status = MergeStatuses(std::move(status), driver_->Close()); + + driver_.reset(); + remote_endpoint_ = nullptr; + ucp_worker_.reset(); + return status; + } + + UcpCallDriver* driver() { + DCHECK(driver_); + return driver_.get(); + } + + private: + static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header, + size_t header_length, void* data, + size_t data_length, + const ucp_am_recv_param_t* param) { + auto* driver = reinterpret_cast(self); + return driver->RecvActiveMessage(header, header_length, data, data_length, param); + } + + std::shared_ptr ucp_worker_; + ucp_ep_h remote_endpoint_; + std::unique_ptr driver_; +}; + +class UcxClientStream : public internal::ClientDataStream { + public: + UcxClientStream(UcxClientImpl* impl, ClientConnection conn) + : impl_(impl), + conn_(std::move(conn)), + driver_(conn_.driver()), + writes_done_(false), + finished_(false) {} + + protected: + Status DoFinish() override; + + UcxClientImpl* impl_; + ClientConnection conn_; + UcpCallDriver* driver_; + bool writes_done_; + bool finished_; + Status io_status_; + Status server_status_; +}; + +class GetClientStream : public UcxClientStream { + public: + GetClientStream(UcxClientImpl* impl, ClientConnection conn) + : UcxClientStream(impl, std::move(conn)) { + writes_done_ = true; + } + + bool ReadData(internal::FlightData* data) override { + if (finished_) return false; + + bool success = true; + io_status_ = ReadImpl(data).Value(&success); + + if (!io_status_.ok() || !success) { + finished_ = true; + } + return success; + } + + private: + ::arrow::Result ReadImpl(internal::FlightData* data) { + ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame()); + + if (frame->type == FrameType::kHeaders) { + // Trailers, stream is over + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); + RETURN_NOT_OK(headers.GetStatus(&server_status_)); + return false; + } + + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader)); + PayloadHeaderFrame payload_header(std::move(frame->buffer)); + RETURN_NOT_OK(payload_header.ToFlightData(data)); + + // DoGet does not support metadata-only messages, so we can always + // assume we have an IPC payload + ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr)); + + if (ipc::Message::HasBody(message->type())) { + ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame()); + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody)); + data->body = std::move(frame->buffer); + } + return true; + } +}; + +class WriteClientStream : public UcxClientStream { + public: + WriteClientStream(UcxClientImpl* impl, ClientConnection conn) + : UcxClientStream(impl, std::move(conn)) { + std::thread t(&WriteClientStream::DriveWorker, this); + driver_thread_.swap(t); + } + arrow::Result WriteData(const FlightPayload& payload) override { + std::unique_lock guard(driver_mutex_); + if (finished_ || writes_done_) return Status::Invalid("Already done writing"); + outgoing_ = driver_->SendFlightPayload(payload); + working_cv_.notify_all(); + completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); + + auto status = outgoing_.status(); + outgoing_ = Future<>(); + RETURN_NOT_OK(status); + return true; + } + Status WritesDone() override { + std::unique_lock guard(driver_mutex_); + if (!writes_done_) { + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make({})); + outgoing_ = + driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer()); + working_cv_.notify_all(); + completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); + + writes_done_ = true; + auto status = outgoing_.status(); + outgoing_ = Future<>(); + RETURN_NOT_OK(status); + } + return Status::OK(); + } + + protected: + void JoinThread() { + try { + driver_thread_.join(); + } catch (const std::system_error&) { + // Ignore + } + } + // Flight's API allows concurrent reads/writes, but the UCX driver + // here is single-threaded, so push all UCX work onto a single + // worker thread + void DriveWorker() { + while (true) { + { + std::unique_lock guard(driver_mutex_); + working_cv_.wait(guard, + [this] { return incoming_.is_valid() || outgoing_.is_valid(); }); + } + + while (true) { + std::unique_lock guard(driver_mutex_); + if (!incoming_.is_valid() && !outgoing_.is_valid()) break; + if (incoming_.is_valid() && incoming_.is_finished()) { + if (!incoming_.status().ok()) { + io_status_ = incoming_.status(); + finished_ = true; + } else { + HandleIncomingMessage(*incoming_.result()); + } + incoming_ = Future>(); + completed_cv_.notify_all(); + break; + } + if (outgoing_.is_valid() && outgoing_.is_finished()) { + completed_cv_.notify_all(); + break; + } + driver_->MakeProgress(); + } + if (finished_) return; + } + } + + virtual void HandleIncomingMessage(const std::shared_ptr& frame) {} + + std::mutex driver_mutex_; + std::thread driver_thread_; + std::condition_variable completed_cv_; + std::condition_variable working_cv_; + Future> incoming_; + Future<> outgoing_; +}; + +class PutClientStream : public WriteClientStream { + public: + using WriteClientStream::WriteClientStream; + bool ReadPutMetadata(std::shared_ptr* out) override { + std::unique_lock guard(driver_mutex_); + if (finished_) { + *out = nullptr; + guard.unlock(); + JoinThread(); + return false; + } + next_metadata_ = nullptr; + incoming_ = driver_->ReadFrameAsync(); + working_cv_.notify_all(); + completed_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; }); + + if (finished_) { + *out = nullptr; + guard.unlock(); + JoinThread(); + return false; + } + *out = std::move(next_metadata_); + return true; + } + + private: + void HandleIncomingMessage(const std::shared_ptr& frame) override { + // No lock here, since this is called from DriveWorker() which is + // holding the lock + if (frame->type == FrameType::kBuffer) { + next_metadata_ = std::move(frame->buffer); + } else if (frame->type == FrameType::kHeaders) { + // Trailers, stream is over + finished_ = true; + HeadersFrame headers; + io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers); + if (!io_status_.ok()) { + finished_ = true; + return; + } + io_status_ = headers.GetStatus(&server_status_); + if (!io_status_.ok()) { + finished_ = true; + return; + } + } else { + finished_ = true; + io_status_ = + Status::IOError("Unexpected frame type ", static_cast(frame->type)); + } + } + std::shared_ptr next_metadata_; +}; + +class ExchangeClientStream : public WriteClientStream { + public: + ExchangeClientStream(UcxClientImpl* impl, ClientConnection conn) + : WriteClientStream(impl, std::move(conn)), read_state_(ReadState::kFinished) {} + + bool ReadData(internal::FlightData* data) override { + std::unique_lock guard(driver_mutex_); + if (finished_) { + guard.unlock(); + JoinThread(); + return false; + } + + // Drive the read loop here. (We can't recursively call + // ReadFrameAsync below since the internal mutex is not + // recursive.) + read_state_ = ReadState::kExpectHeader; + incoming_ = driver_->ReadFrameAsync(); + working_cv_.notify_all(); + completed_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; }); + if (read_state_ != ReadState::kFinished) { + incoming_ = driver_->ReadFrameAsync(); + working_cv_.notify_all(); + completed_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; }); + } + + if (finished_) { + guard.unlock(); + JoinThread(); + return false; + } + *data = std::move(next_data_); + return true; + } + + private: + enum class ReadState { + kFinished, + kExpectHeader, + kExpectBody, + }; + + std::string DebugExpectingString() { + switch (read_state_) { + case ReadState::kFinished: + return "(not expecting a frame)"; + case ReadState::kExpectHeader: + return "payload header frame"; + case ReadState::kExpectBody: + return "payload body frame"; + } + return "(unknown or invalid state)"; + } + + void HandleIncomingMessage(const std::shared_ptr& frame) override { + // No lock here, since this is called from MakeProgress() + // which is called under the lock already + if (frame->type == FrameType::kPayloadHeader) { + if (read_state_ != ReadState::kExpectHeader) { + finished_ = true; + io_status_ = Status::IOError("Got unexpected payload header frame, expected: ", + DebugExpectingString()); + return; + } + + PayloadHeaderFrame payload_header(std::move(frame->buffer)); + io_status_ = payload_header.ToFlightData(&next_data_); + if (!io_status_.ok()) { + finished_ = true; + return; + } + + if (next_data_.metadata) { + std::unique_ptr message; + io_status_ = ipc::Message::Open(next_data_.metadata, nullptr).Value(&message); + if (!io_status_.ok()) { + finished_ = true; + return; + } + if (ipc::Message::HasBody(message->type())) { + read_state_ = ReadState::kExpectBody; + return; + } + } + read_state_ = ReadState::kFinished; + } else if (frame->type == FrameType::kPayloadBody) { + next_data_.body = std::move(frame->buffer); + read_state_ = ReadState::kFinished; + } else if (frame->type == FrameType::kHeaders) { + // Trailers, stream is over + finished_ = true; + read_state_ = ReadState::kFinished; + HeadersFrame headers; + io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers); + if (!io_status_.ok()) { + finished_ = true; + return; + } + io_status_ = headers.GetStatus(&server_status_); + if (!io_status_.ok()) { + finished_ = true; + return; + } + } else { + finished_ = true; + io_status_ = + Status::IOError("Unexpected frame type ", static_cast(frame->type)); + read_state_ = ReadState::kFinished; + } + } + + internal::FlightData next_data_; + ReadState read_state_; +}; + +class UcxClientImpl : public arrow::flight::internal::ClientTransport { + public: + UcxClientImpl() {} + + virtual ~UcxClientImpl() { + if (!ucp_context_) return; + auto status = Close(); + if (!status.ok()) { + ARROW_LOG(WARNING) << "UcxClientImpl errored in Close() in destructor: " + << status.ToString(); + } + } + + Status Init(const FlightClientOptions& options, const Location& location, + const arrow::internal::Uri& uri) override { + RETURN_NOT_OK(uri_.Parse(uri.ToString())); + { + ucp_config_t* ucp_config; + ucp_params_t ucp_params; + ucs_status_t status; + + status = ucp_config_read(nullptr, nullptr, &ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); + + // If location is IPv6, must adjust UCX config + // XXX: we assume locations always resolve to IPv6 or IPv4 but + // that is not necessarily true. + { + struct sockaddr_storage connect_addr; + RETURN_NOT_OK(UriToSockaddr(uri, &connect_addr)); + if (connect_addr.ss_family == AF_INET6) { + status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); + RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status)); + } + } + + std::memset(&ucp_params, 0, sizeof(ucp_params)); + ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; + ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; + + ucp_context_h ucp_context; + status = ucp_init(&ucp_params, ucp_config, &ucp_context); + ucp_config_release(ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_init", status)); + ucp_context_.reset(new UcpContext(ucp_context)); + } + + RETURN_NOT_OK(MakeConnection()); + return Status::OK(); + } + + Status Close() override { + std::unique_lock connections_mutex_; + while (!connections_.empty()) { + RETURN_NOT_OK(connections_.front().Close()); + connections_.pop_front(); + } + return Status::OK(); + } + + Status GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* info) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto impl = [&]() { + RETURN_NOT_OK(driver->StartCall(kMethodGetFlightInfo)); + + ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString()); + + RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, + reinterpret_cast(payload.data()), + static_cast(payload.size()))); + + ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame()); + if (incoming_message->type == FrameType::kBuffer) { + ARROW_ASSIGN_OR_RAISE( + *info, FlightInfo::Deserialize(util::string_view(*incoming_message->buffer))); + ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame()); + } + RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders)); + ARROW_ASSIGN_OR_RAISE(auto headers, + HeadersFrame::Parse(std::move(incoming_message->buffer))); + Status status; + RETURN_NOT_OK(headers.GetStatus(&status)); + return status; + }; + auto status = impl(); + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoExchange(const FlightCallOptions& options, + std::unique_ptr* out) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto status = driver->StartCall(kMethodDoExchange); + if (ARROW_PREDICT_TRUE(status.ok())) { + *out = + arrow::internal::make_unique(this, std::move(connection)); + return Status::OK(); + } + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* stream) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto impl = [&]() { + RETURN_NOT_OK(driver->StartCall(kMethodDoGet)); + ARROW_ASSIGN_OR_RAISE(std::string payload, ticket.SerializeToString()); + RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, + reinterpret_cast(payload.data()), + static_cast(payload.size()))); + *stream = + arrow::internal::make_unique(this, std::move(connection)); + return Status::OK(); + }; + + auto status = impl(); + if (ARROW_PREDICT_TRUE(status.ok())) return status; + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoPut(const FlightCallOptions& options, + std::unique_ptr* out) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto status = driver->StartCall(kMethodDoPut); + if (ARROW_PREDICT_TRUE(status.ok())) { + *out = arrow::internal::make_unique(this, std::move(connection)); + return Status::OK(); + } + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) override { + // XXX: fake this for now to get the perf test to work + return Status::OK(); + } + + Status MakeConnection() { + ClientConnection conn; + RETURN_NOT_OK(conn.Init(ucp_context_, uri_)); + connections_.push_back(std::move(conn)); + return Status::OK(); + } + + arrow::Result CheckoutConnection(const FlightCallOptions& options) { + std::unique_lock connections_mutex_; + if (connections_.empty()) RETURN_NOT_OK(MakeConnection()); + ClientConnection conn = std::move(connections_.front()); + conn.driver()->set_memory_manager(options.memory_manager); + conn.driver()->set_read_memory_pool(options.read_options.memory_pool); + conn.driver()->set_write_memory_pool(options.write_options.memory_pool); + connections_.pop_front(); + return conn; + } + + Status ReturnConnection(ClientConnection conn) { + std::unique_lock connections_mutex_; + // TODO(ARROW-16127): for future improvement: reclaim clients + // asynchronously in the background (try to avoid issues like + // constantly opening/closing clients because the application is + // just barely over the limit of open connections) + if (connections_.size() >= kMaxOpenConnections) { + RETURN_NOT_OK(conn.Close()); + return Status::OK(); + } + connections_.push_back(std::move(conn)); + return Status::OK(); + } + + private: + static constexpr size_t kMaxOpenConnections = 3; + + arrow::internal::Uri uri_; + std::shared_ptr ucp_context_; + std::mutex connections_mutex_; + std::deque connections_; +}; + +Status UcxClientStream::DoFinish() { + RETURN_NOT_OK(WritesDone()); + if (!finished_) { + internal::FlightData message; + std::shared_ptr metadata; + while (ReadData(&message)) { + } + while (ReadPutMetadata(&metadata)) { + } + finished_ = true; + } + if (impl_) { + auto status = impl_->ReturnConnection(std::move(conn_)); + impl_ = nullptr; + driver_ = nullptr; + if (!status.ok()) { + if (io_status_.ok()) { + io_status_ = std::move(status); + } else { + io_status_ = Status::FromDetailAndArgs( + io_status_.code(), io_status_.detail(), io_status_.message(), + ". Transport context: ", status.ToString()); + } + } + } + return MergeStatuses(server_status_, io_status_); +} +} // namespace + +std::unique_ptr MakeUcxClientImpl() { + return arrow::internal::make_unique(); +} + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc new file mode 100644 index 00000000000..ab4cc323f4c --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -0,0 +1,1171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/flight/types.h" +#include "arrow/util/base64.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +// Defines to test different implementation strategies +// Enable the CONTIG path for CPU-only data +// #define ARROW_FLIGHT_UCX_SEND_CONTIG +// Enable ucp_mem_map in IOV path +// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP + +constexpr char kHeaderMethod[] = ":method:"; + +namespace { +Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) { + if (ARROW_PREDICT_FALSE(in < 0)) { + return Status::Invalid("Length cannot be negative"); + } else if (ARROW_PREDICT_FALSE( + in > static_cast(std::numeric_limits::max()))) { + return Status::Invalid("Length cannot exceed uint32_t"); + } + UInt32ToBytesBe(static_cast(in), out); + return Status::OK(); +} +ucs_memory_type InferMemoryType(const Buffer& buffer) { + if (!buffer.is_cpu()) { + return UCS_MEMORY_TYPE_CUDA; + } + return UCS_MEMORY_TYPE_UNKNOWN; +} +void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size, + ucs_memory_type memory_type, ucp_mem_h* memh_p) { + ucp_mem_map_params_t map_param; + std::memset(&map_param, 0, sizeof(map_param)); + map_param.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; + map_param.address = const_cast(buffer); + map_param.length = size; + map_param.memory_type = memory_type; + auto ucs_status = ucp_mem_map(context, &map_param, memh_p); + if (ucs_status != UCS_OK) { + *memh_p = nullptr; + ARROW_LOG(WARNING) << "Could not map memory: " + << FromUcsStatus("ucp_mem_map", ucs_status); + } +} +void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) { + TryMapBuffer(context, reinterpret_cast(buffer.address()), + static_cast(buffer.size()), InferMemoryType(buffer), memh_p); +} +void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) { + if (memh_p) { + auto ucs_status = ucp_mem_unmap(context, memh_p); + if (ucs_status != UCS_OK) { + ARROW_LOG(WARNING) << "Could not unmap memory: " + << FromUcsStatus("ucp_mem_unmap", ucs_status); + } + } +} + +/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA +/// buffer). +/// +/// Owns a reference to the associated worker to avoid undefined +/// behavior. +class UcxDataBuffer : public Buffer { + public: + explicit UcxDataBuffer(std::shared_ptr worker, void* data, size_t size) + : Buffer(reinterpret_cast(data), static_cast(size)), + worker_(std::move(worker)) {} + + ~UcxDataBuffer() { + ucp_am_data_release(worker_->get(), + const_cast(reinterpret_cast(data()))); + } + + private: + std::shared_ptr worker_; +}; +}; // namespace + +constexpr size_t FrameHeader::kFrameHeaderBytes; +constexpr uint8_t FrameHeader::kFrameVersion; + +Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) { + header[0] = kFrameVersion; + header[1] = static_cast(frame_type); + UInt32ToBytesBe(counter, header.data() + 4); + RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8)); + return Status::OK(); +} + +arrow::Result> Frame::ParseHeader(const void* header, + size_t header_length) { + if (header_length < FrameHeader::kFrameHeaderBytes) { + return Status::IOError("Header is too short, must be at least ", + FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length); + } + + const uint8_t* frame_header = reinterpret_cast(header); + if (frame_header[0] != FrameHeader::kFrameVersion) { + return Status::IOError("Expected frame version ", + static_cast(FrameHeader::kFrameVersion), " but got ", + static_cast(frame_header[0])); + } else if (frame_header[1] > static_cast(FrameType::kMaxFrameType)) { + return Status::IOError("Unknown frame type ", static_cast(frame_header[1])); + } + + const FrameType frame_type = static_cast(frame_header[1]); + const uint32_t frame_counter = BytesToUInt32Be(frame_header + 4); + const uint32_t frame_size = BytesToUInt32Be(frame_header + 8); + + if (frame_type == FrameType::kDisconnect) { + return Status::Cancelled("Client initiated disconnect"); + } + + return std::make_shared(frame_type, frame_size, frame_counter, nullptr); +} + +arrow::Result HeadersFrame::Parse(std::unique_ptr buffer) { + HeadersFrame result; + const uint8_t* payload = buffer->data(); + const uint8_t* end = payload + buffer->size(); + if (ARROW_PREDICT_FALSE((end - payload) < 4)) { + return Status::Invalid("Buffer underflow, expected number of headers"); + } + const uint32_t num_headers = BytesToUInt32Be(payload); + payload += 4; + for (uint32_t i = 0; i < num_headers; i++) { + if (ARROW_PREDICT_FALSE((end - payload) < 4)) { + return Status::Invalid("Buffer underflow, expected length of key ", i + 1); + } + const uint32_t key_length = BytesToUInt32Be(payload); + payload += 4; + + if (ARROW_PREDICT_FALSE((end - payload) < 4)) { + return Status::Invalid("Buffer underflow, expected length of value ", i + 1); + } + const uint32_t value_length = BytesToUInt32Be(payload); + payload += 4; + + if (ARROW_PREDICT_FALSE((end - payload) < key_length)) { + return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ", + key_length, ", but only ", (end - payload), " bytes remain"); + } + const util::string_view key(reinterpret_cast(payload), key_length); + payload += key_length; + + if (ARROW_PREDICT_FALSE((end - payload) < value_length)) { + return Status::Invalid("Buffer underflow, expected value ", i + 1, + " to have length ", value_length, ", but only ", + (end - payload), " bytes remain"); + } + const util::string_view value(reinterpret_cast(payload), value_length); + payload += value_length; + result.headers_.emplace_back(key, value); + } + + result.buffer_ = std::move(buffer); + return result; +} +arrow::Result HeadersFrame::Make( + const std::vector>& headers) { + int32_t total_length = 4 /* # of headers */; + for (const auto& header : headers) { + total_length += 4 /* key length */ + 4 /* value length */ + + header.first.size() /* key */ + header.second.size(); + } + + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length)); + uint8_t* payload = buffer->mutable_data(); + + RETURN_NOT_OK(SizeToUInt32BytesBe(headers.size(), payload)); + payload += 4; + for (const auto& header : headers) { + RETURN_NOT_OK(SizeToUInt32BytesBe(header.first.size(), payload)); + payload += 4; + RETURN_NOT_OK(SizeToUInt32BytesBe(header.second.size(), payload)); + payload += 4; + std::memcpy(payload, header.first.data(), header.first.size()); + payload += header.first.size(); + std::memcpy(payload, header.second.data(), header.second.size()); + payload += header.second.size(); + } + return Parse(std::move(buffer)); +} +arrow::Result HeadersFrame::Make( + const Status& status, + const std::vector>& headers) { + auto all_headers = headers; + all_headers.emplace_back(kHeaderStatusCode, + std::to_string(static_cast(status.code()))); + all_headers.emplace_back(kHeaderStatusMessage, status.message()); + if (status.detail()) { + auto fsd = FlightStatusDetail::UnwrapStatus(status); + if (fsd) { + all_headers.emplace_back(kHeaderStatusDetailCode, + std::to_string(static_cast(fsd->code()))); + all_headers.emplace_back(kHeaderStatusDetail, fsd->extra_info()); + } else { + all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString()); + } + } + return Make(all_headers); +} + +arrow::Result HeadersFrame::Get(const std::string& key) { + for (const auto& pair : headers_) { + if (pair.first == key) return pair.second; + } + return Status::KeyError(key); +} + +Status HeadersFrame::GetStatus(Status* out) { + util::string_view code_str, message_str; + auto status = Get(kHeaderStatusCode).Value(&code_str); + if (!status.ok()) { + return Status::KeyError("Server did not send status code header ", kHeaderStatusCode); + } + + StatusCode status_code = StatusCode::OK; + auto code = std::strtol(code_str.data(), nullptr, /*base=*/10); + switch (code) { + case 0: + status_code = StatusCode::OK; + break; + case 1: + status_code = StatusCode::OutOfMemory; + break; + case 2: + status_code = StatusCode::KeyError; + break; + case 3: + status_code = StatusCode::TypeError; + break; + case 4: + status_code = StatusCode::Invalid; + break; + case 5: + status_code = StatusCode::IOError; + break; + case 6: + status_code = StatusCode::CapacityError; + break; + case 7: + status_code = StatusCode::IndexError; + break; + case 8: + status_code = StatusCode::Cancelled; + break; + case 9: + status_code = StatusCode::UnknownError; + break; + case 10: + status_code = StatusCode::NotImplemented; + break; + case 11: + status_code = StatusCode::SerializationError; + break; + case 13: + status_code = StatusCode::RError; + break; + case 40: + status_code = StatusCode::CodeGenError; + break; + case 41: + status_code = StatusCode::ExpressionValidationError; + break; + case 42: + status_code = StatusCode::ExecutionError; + break; + case 45: + status_code = StatusCode::AlreadyExists; + break; + default: + status_code = StatusCode::UnknownError; + break; + } + if (status_code == StatusCode::OK) { + *out = Status::OK(); + return Status::OK(); + } + + status = Get(kHeaderStatusMessage).Value(&message_str); + if (!status.ok()) { + *out = Status(status_code, "Server did not send status message header", nullptr); + return Status::OK(); + } + + util::string_view detail_code_str, detail_str; + FlightStatusCode detail_code = FlightStatusCode::Internal; + + if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) { + auto detail_code_int = std::strtol(detail_code_str.data(), nullptr, /*base=*/10); + switch (detail_code_int) { + case 1: + detail_code = FlightStatusCode::TimedOut; + break; + case 2: + detail_code = FlightStatusCode::Cancelled; + break; + case 3: + detail_code = FlightStatusCode::Unauthenticated; + break; + case 4: + detail_code = FlightStatusCode::Unauthorized; + break; + case 5: + detail_code = FlightStatusCode::Unavailable; + break; + case 6: + detail_code = FlightStatusCode::Failed; + break; + case 0: + default: + detail_code = FlightStatusCode::Internal; + break; + } + } + ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str)); + + std::shared_ptr detail = nullptr; + if (!detail_str.empty()) { + detail = std::make_shared(detail_code, std::string(detail_str)); + } + *out = Status(status_code, std::string(message_str), std::move(detail)); + return Status::OK(); +} + +namespace { +static constexpr uint32_t kMissingFieldSentinel = std::numeric_limits::max(); +static constexpr uint32_t kInt32Max = + static_cast(std::numeric_limits::max()); +arrow::Result PayloadHeaderFieldSize(const std::string& field, + const std::shared_ptr& data, + uint32_t* total_size) { + if (!data) return kMissingFieldSentinel; + if (data->size() > kInt32Max) { + return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size()); + } + *total_size += static_cast(data->size()); + // Check for underflow + if (*total_size < 0) return Status::Invalid("Payload header must fit in a uint32_t"); + return static_cast(data->size()); +} +uint8_t* PackField(uint32_t size, const std::shared_ptr& data, uint8_t* out) { + UInt32ToBytesBe(size, out); + if (size != kMissingFieldSentinel) { + std::memcpy(out + 4, data->data(), size); + return out + 4 + size; + } else { + return out + 4; + } +} +} // namespace + +arrow::Result PayloadHeaderFrame::Make(const FlightPayload& payload, + MemoryPool* memory_pool) { + // Assemble all non-data fields here. Presumably this is much less + // than data size so we will pay the copy. + + // Structure per field: [4 byte length][data]. If a field is not + // present, UINT32_MAX is used as the sentinel (since 0-sized fields + // are acceptable) + uint32_t header_size = 12; + ARROW_ASSIGN_OR_RAISE( + const uint32_t descriptor_size, + PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size)); + ARROW_ASSIGN_OR_RAISE( + const uint32_t app_metadata_size, + PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size)); + ARROW_ASSIGN_OR_RAISE( + const uint32_t ipc_metadata_size, + PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata, + &header_size)); + + ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool)); + uint8_t* payload_header = header_buffer->mutable_data(); + + payload_header = PackField(descriptor_size, payload.descriptor, payload_header); + payload_header = PackField(app_metadata_size, payload.app_metadata, payload_header); + payload_header = + PackField(ipc_metadata_size, payload.ipc_message.metadata, payload_header); + + return PayloadHeaderFrame(std::move(header_buffer)); +} +Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) { + std::shared_ptr buffer = std::move(buffer_); + + // Unpack the descriptor + uint32_t offset = 0; + uint32_t size = BytesToUInt32Be(buffer->data()); + offset += 4; + if (size != kMissingFieldSentinel) { + if (static_cast(offset + size) > buffer->size()) { + return Status::Invalid("Buffer is too small: expected ", offset + size, + " bytes but have ", buffer->size()); + } + util::string_view desc(reinterpret_cast(buffer->data() + offset), size); + data->descriptor.reset(new FlightDescriptor()); + ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc)); + offset += size; + } else { + data->descriptor = nullptr; + } + + // Unpack app_metadata + size = BytesToUInt32Be(buffer->data() + offset); + offset += 4; + // While we properly handle zero-size vs nullptr metadata here, gRPC + // doesn't (Protobuf doesn't differentiate between the two) + if (size != kMissingFieldSentinel) { + if (static_cast(offset + size) > buffer->size()) { + return Status::Invalid("Buffer is too small: expected ", offset + size, + " bytes but have ", buffer->size()); + } + data->app_metadata = SliceBuffer(buffer, offset, size); + offset += size; + } else { + data->app_metadata = nullptr; + } + + // Unpack the IPC header + size = BytesToUInt32Be(buffer->data() + offset); + offset += 4; + if (size != kMissingFieldSentinel) { + if (static_cast(offset + size) > buffer->size()) { + return Status::Invalid("Buffer is too small: expected ", offset + size, + " bytes but have ", buffer->size()); + } + data->metadata = SliceBuffer(std::move(buffer), offset, size); + } else { + data->metadata = nullptr; + } + data->body = nullptr; + return Status::OK(); +} + +// pImpl the driver since async methods require a stable address +class UcpCallDriver::Impl { + public: +#if defined(ARROW_FLIGHT_UCX_SEND_CONTIG) + constexpr static bool kEnableContigSend = true; +#else + constexpr static bool kEnableContigSend = false; +#endif + + Impl(std::shared_ptr worker, ucp_ep_h endpoint) + : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}), + worker_(std::move(worker)), + endpoint_(endpoint), + read_memory_pool_(default_memory_pool()), + write_memory_pool_(default_memory_pool()), + memory_manager_(CPUDevice::Instance()->default_memory_manager()), + name_("(unknown remote)"), + counter_(0) { +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(), + UCS_MEMORY_TYPE_HOST, &padding_memh_p_); +#endif + + ucp_ep_attr_t attrs; + std::memset(&attrs, 0, sizeof(attrs)); + attrs.field_mask = + UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR | UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR; + if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) { + std::string local_addr, remote_addr; + ARROW_UNUSED(SockaddrToString(attrs.local_sockaddr).Value(&local_addr)); + ARROW_UNUSED(SockaddrToString(attrs.remote_sockaddr).Value(&remote_addr)); + name_ = "local:" + local_addr + ";remote:" + remote_addr; + } + } + + ~Impl() { +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + TryUnmapBuffer(worker_->context().get(), padding_memh_p_); +#endif + } + + arrow::Result> ReadNextFrame() { + auto fut = ReadFrameAsync(); + while (!fut.is_finished()) MakeProgress(); + RETURN_NOT_OK(fut.status()); + return fut.MoveResult(); + } + + Future> ReadFrameAsync() { + RETURN_NOT_OK(CheckClosed()); + + std::unique_lock guard(frame_mutex_); + if (ARROW_PREDICT_FALSE(!status_.ok())) return status_; + + // Expected value of "counter" field in the frame header + const uint32_t counter_value = next_counter_++; + auto it = frames_.find(counter_value); + if (it != frames_.end()) { + // Message already delivered, return it + Future> fut = it->second; + frames_.erase(it); + return fut; + } + // Message not yet delivered, insert a future and wait + auto pair = frames_.insert({counter_value, Future>::Make()}); + DCHECK(pair.second); + return pair.first->second; + } + + Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) { + RETURN_NOT_OK(CheckClosed()); + + void* request = nullptr; + ucp_request_param_t request_param; + request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS; + request_param.flags = UCP_AM_SEND_FLAG_REPLY; + + // Send frame header + FrameHeader header; + RETURN_NOT_OK(header.Set(frame_type, counter_++, size)); + if (size == 0) { + // UCX appears to crash on zero-byte payloads + request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), + padding_bytes_.data(), + /*size=*/1, &request_param); + } else { + request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), + data, size, &request_param); + } + RETURN_NOT_OK(CompleteRequestBlocking("ucp_am_send_nbx", request)); + + return Status::OK(); + } + + Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer) { + RETURN_NOT_OK(CheckClosed()); + + ucp_request_param_t request_param; + std::memset(&request_param, 0, sizeof(request_param)); + request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA; + request_param.cb.send = AmSendCallback; + request_param.datatype = ucp_dt_make_contig(1); + request_param.flags = UCP_AM_SEND_FLAG_REPLY; + + const int64_t size = buffer->size(); + if (size == 0) { + // UCX appears to crash on zero-byte payloads + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(1, write_memory_pool_)); + } + + std::unique_ptr pending_send(new PendingContigSend()); + RETURN_NOT_OK(pending_send->header.Set(frame_type, counter_++, size)); + pending_send->ipc_message = std::move(buffer); + pending_send->driver = this; + pending_send->completed = Future<>::Make(); + pending_send->memh_p = nullptr; + + request_param.user_data = pending_send.release(); + { + auto* pending_send = reinterpret_cast(request_param.user_data); + + void* request = ucp_am_send_nbx( + endpoint_, kUcpAmHandlerId, pending_send->header.data(), + pending_send->header.size(), + reinterpret_cast(pending_send->ipc_message->mutable_data()), + static_cast(pending_send->ipc_message->size()), &request_param); + if (!request) { + // Request completed immediately + delete pending_send; + return Status::OK(); + } else if (UCS_PTR_IS_ERR(request)) { + delete pending_send; + return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request)); + } + return pending_send->completed; + } + } + + Future<> SendFlightPayload(const FlightPayload& payload) { + static const int64_t kMaxBatchSize = std::numeric_limits::max(); + RETURN_NOT_OK(CheckClosed()); + + if (payload.ipc_message.body_length > kMaxBatchSize) { + return Status::Invalid("Cannot send record batches exceeding 2GiB yet"); + } + + { + ARROW_ASSIGN_OR_RAISE(auto frame, + PayloadHeaderFrame::Make(payload, write_memory_pool_)); + RETURN_NOT_OK(SendFrame(FrameType::kPayloadHeader, frame.data(), frame.size())); + } + + if (!ipc::Message::HasBody(payload.ipc_message.type)) { + return Status::OK(); + } + + // While IOV (scatter-gather) might seem like it avoids a memcpy, + // profiling shows that at least for the TCP/SHM/RDMA transports, + // UCX just does a memcpy internally. Furthermore, on the receiver + // side, a sender-side IOV send prevents optimizations based on + // mapped buffers (UCX will memcpy to the destination buffer + // regardless of whether it's mapped or not). + + // If all buffers are on the CPU, concatenate them ourselves and + // do a regular send to avoid this. Else, use IOV and let UCX + // figure out what to do. + + // Weirdness: UCX prefers TCP over shared memory for CONTIG? We + // can avoid this by setting UCX_RNDV_THRESH=inf, this will make + // UCX prefer shared memory again. However, we still want to avoid + // the CONTIG path when shared memory is available, because the + // total amount of time spent in memcpy is greater than using IOV + // and letting UCX handle it. + + // Consider: if we can figure out how to make IOV always as fast + // as CONTIG, we can just send the metadata fields as part of the + // IOV payload and avoid having to send two distinct messages. + + bool all_cpu = true; + int32_t total_buffers = 0; + for (const auto& buffer : payload.ipc_message.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + all_cpu = all_cpu && buffer->is_cpu(); + total_buffers++; + + // Arrow IPC requires that we align buffers to 8 byte boundary + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) total_buffers++; + } + + ucp_request_param_t request_param; + std::memset(&request_param, 0, sizeof(request_param)); + request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA; + request_param.cb.send = AmSendCallback; + request_param.flags = UCP_AM_SEND_FLAG_REPLY; + + std::unique_ptr pending_send; + void* send_data = nullptr; + size_t send_size = 0; + + if (!all_cpu) { + request_param.op_attr_mask = + request_param.op_attr_mask | UCP_OP_ATTR_FIELD_MEMORY_TYPE; + // XXX: UCX doesn't appear to autodetect this correctly if we + // use UNKNOWN + request_param.memory_type = UCS_MEMORY_TYPE_CUDA; + } + + if (kEnableContigSend && all_cpu) { + // CONTIG - concatenate buffers into one before sending + + // TODO(ARROW-16126): this needs to be pipelined since it can be expensive. + // Preliminary profiling shows ~5% overhead just from mapping the buffer + // alone (on Infiniband; it seems to be trivial for shared memory) + request_param.datatype = ucp_dt_make_contig(1); + pending_send = arrow::internal::make_unique(); + auto* pending_contig = reinterpret_cast(pending_send.get()); + + const int64_t body_length = std::max(payload.ipc_message.body_length, 1); + ARROW_ASSIGN_OR_RAISE(pending_contig->ipc_message, + AllocateBuffer(body_length, write_memory_pool_)); + TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message, + &pending_contig->memh_p); + + uint8_t* ipc_message = pending_contig->ipc_message->mutable_data(); + if (payload.ipc_message.body_length == 0) { + std::memset(ipc_message, '\0', 1); + } + + for (const auto& buffer : payload.ipc_message.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + + std::memcpy(ipc_message, buffer->data(), buffer->size()); + ipc_message += buffer->size(); + + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) { + std::memset(ipc_message, 0, remainder); + ipc_message += remainder; + } + } + + send_data = reinterpret_cast(pending_contig->ipc_message->mutable_data()); + send_size = static_cast(pending_contig->ipc_message->size()); + } else { + // IOV - let UCX use scatter-gather path + request_param.datatype = UCP_DATATYPE_IOV; + pending_send = arrow::internal::make_unique(); + auto* pending_iov = reinterpret_cast(pending_send.get()); + + pending_iov->payload = payload; + pending_iov->iovs.resize(total_buffers); + ucp_dt_iov_t* iov = pending_iov->iovs.data(); +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + // XXX: this seems to have no benefits in tests so far + pending_iov->memh_ps.resize(total_buffers); + ucp_mem_h* memh_p = pending_iov->memh_ps.data(); +#endif + for (const auto& buffer : payload.ipc_message.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + + iov->buffer = const_cast(reinterpret_cast(buffer->address())); + iov->length = buffer->size(); + ++iov; + +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + TryMapBuffer(worker_->context().get(), *buffer, memh_p); + memh_p++; +#endif + + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) { + iov->buffer = + const_cast(reinterpret_cast(padding_bytes_.data())); + iov->length = remainder; + ++iov; + } + } + + if (total_buffers == 0) { + // UCX cannot handle zero-byte payloads + pending_iov->iovs.resize(1); + pending_iov->iovs[0].buffer = + const_cast(reinterpret_cast(padding_bytes_.data())); + pending_iov->iovs[0].length = 1; + } + + send_data = pending_iov->iovs.data(); + send_size = pending_iov->iovs.size(); + } + + DCHECK(send_data) << "Payload cannot be nullptr"; + DCHECK_GT(send_size, 0) << "Payload cannot be empty"; + + RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++, + payload.ipc_message.body_length)); + pending_send->driver = this; + pending_send->completed = Future<>::Make(); + + request_param.user_data = pending_send.release(); + { + auto* pending_send = reinterpret_cast(request_param.user_data); + + void* request = ucp_am_send_nbx( + endpoint_, kUcpAmHandlerId, pending_send->header.data(), + pending_send->header.size(), send_data, send_size, &request_param); + if (!request) { + // Request completed immediately + delete pending_send; + return Status::OK(); + } else if (UCS_PTR_IS_ERR(request)) { + delete pending_send; + return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request)); + } + return pending_send->completed; + } + } + + Status Close() { + if (!endpoint_) return Status::OK(); + + for (auto& item : frames_) { + item.second.MarkFinished(Status::Cancelled("UcpCallDriver is being closed")); + } + frames_.clear(); + + void* request = ucp_ep_close_nb(endpoint_, UCP_EP_CLOSE_MODE_FLUSH); + ucs_status_t status = UCS_OK; + std::string origin = "ucp_ep_close_nb"; + if (UCS_PTR_IS_ERR(request)) { + status = UCS_PTR_STATUS(request); + } else if (UCS_PTR_IS_PTR(request)) { + origin = "ucp_request_check_status"; + while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) { + MakeProgress(); + } + ucp_request_free(request); + } else { + DCHECK(!request); + } + + endpoint_ = nullptr; + if (status != UCS_OK && !IsIgnorableDisconnectError(status)) { + return FromUcsStatus(origin, status); + } + return Status::OK(); + } + + void MakeProgress() { ucp_worker_progress(worker_->get()); } + + void Push(std::shared_ptr frame) { + std::unique_lock guard(frame_mutex_); + if (ARROW_PREDICT_FALSE(!status_.ok())) return; + auto pair = frames_.insert({frame->counter, frame}); + if (!pair.second) { + // Not inserted, because ReadFrameAsync was called for this + // frame counter value and the client is already waiting on + // it. Complete the existing future. + pair.first->second.MarkFinished(std::move(frame)); + frames_.erase(pair.first); + } + // Otherwise, we inserted the frame, meaning the client was not + // currently waiting for that frame counter value + } + + void Push(Status status) { + std::unique_lock guard(frame_mutex_); + status_ = std::move(status); + for (auto& item : frames_) { + item.second.MarkFinished(status_); + } + frames_.clear(); + } + + ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data, + const size_t data_length, + const ucp_am_recv_param_t* param) { + auto maybe_status = + RecvActiveMessageImpl(header, header_length, data, data_length, param); + if (!maybe_status.ok()) { + Push(maybe_status.status()); + return UCS_OK; + } + return maybe_status.MoveValueUnsafe(); + } + + const std::shared_ptr& memory_manager() const { return memory_manager_; } + void set_memory_manager(std::shared_ptr memory_manager) { + if (memory_manager) { + memory_manager_ = std::move(memory_manager); + } else { + memory_manager_ = CPUDevice::Instance()->default_memory_manager(); + } + } + void set_read_memory_pool(MemoryPool* pool) { + read_memory_pool_ = pool ? pool : default_memory_pool(); + } + void set_write_memory_pool(MemoryPool* pool) { + write_memory_pool_ = pool ? pool : default_memory_pool(); + } + const std::string& peer() const { return name_; } + + private: + class PendingAmSend { + public: + virtual ~PendingAmSend() = default; + UcpCallDriver::Impl* driver; + Future<> completed; + FrameHeader header; + }; + + class PendingContigSend : public PendingAmSend { + public: + std::unique_ptr ipc_message; + ucp_mem_h memh_p; + + virtual ~PendingContigSend() { + TryUnmapBuffer(driver->worker_->context().get(), memh_p); + } + }; + + class PendingIovSend : public PendingAmSend { + public: + FlightPayload payload; + std::vector iovs; +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + std::vector memh_ps; + + virtual ~PendingIovSend() { + for (ucp_mem_h memh_p : memh_ps) { + TryUnmapBuffer(driver->worker_->context().get(), memh_p); + } + } +#endif + }; + + struct PendingAmRecv { + UcpCallDriver::Impl* driver; + std::shared_ptr frame; + ucp_mem_h memh_p; + + PendingAmRecv(UcpCallDriver::Impl* driver_, std::shared_ptr frame_) + : driver(driver_), frame(std::move(frame_)) {} + + ~PendingAmRecv() { TryUnmapBuffer(driver->worker_->context().get(), memh_p); } + }; + + static void AmSendCallback(void* request, ucs_status_t status, void* user_data) { + auto* pending_send = reinterpret_cast(user_data); + if (status == UCS_OK) { + pending_send->completed.MarkFinished(); + } else { + pending_send->completed.MarkFinished(FromUcsStatus("ucp_am_send_nbx", status)); + } + // TODO(ARROW-16126): delete should occur on a background thread if there's + // mapped buffers, since unmapping can be nontrivial and we don't want to block + // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.) + delete pending_send; + ucp_request_free(request); + } + + static void AmRecvCallback(void* request, ucs_status_t status, size_t length, + void* user_data) { + auto* pending_recv = reinterpret_cast(user_data); + ucp_request_free(request); + if (status != UCS_OK) { + pending_recv->driver->Push( + FromUcsStatus("ucp_am_recv_data_nbx (callback)", status)); + } else { + pending_recv->driver->Push(std::move(pending_recv->frame)); + } + delete pending_recv; + } + + arrow::Result RecvActiveMessageImpl(const void* header, + size_t header_length, void* data, + const size_t data_length, + const ucp_am_recv_param_t* param) { + DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP); + + if (data_length > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Cannot allocate buffer greater than 2 GiB, requested: ", + data_length); + } + + ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length)); + if (data_length < frame->size) { + return Status::IOError("Expected frame of ", frame->size, " bytes, but got only ", + data_length); + } + + if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) && + (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody)) { + // Zero-copy path. UCX-allocated buffer must be freed later. + + // XXX: this buffer can NOT be freed until AFTER we return from + // this handler. Otherwise, UCX won't have fully set up its + // internal data structures (allocated just before the buffer) + // and we'll crash when we free the buffer. Effectively: we can + // never use Then/AddCallback on a Future<> from ReadFrameAsync, + // because we might run the callback synchronously (which might + // free the buffer) when we call Push here. + frame->buffer = + arrow::internal::make_unique(worker_, data, data_length); + Push(std::move(frame)); + return UCS_INPROGRESS; + } + + if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) || + (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV)) { + // Rendezvous protocol (RNDV), or unpack to destination (DATA). + + // We want to map/pin/register the buffer for faster transfer + // where possible. (It gets unmapped in ~PendingAmRecv.) + // TODO(ARROW-16126): This takes non-trivial time, so return + // UCS_INPROGRESS, kick off the allocation in the background, + // and recv the data later (is it allowed to call + // ucp_am_recv_data_nbx asynchronously?). + if (frame->type == FrameType::kPayloadBody) { + ARROW_ASSIGN_OR_RAISE(frame->buffer, + memory_manager_->AllocateBuffer(data_length)); + } else { + ARROW_ASSIGN_OR_RAISE(frame->buffer, + AllocateBuffer(data_length, read_memory_pool_)); + } + + PendingAmRecv* pending_recv = new PendingAmRecv(this, std::move(frame)); + TryMapBuffer(worker_->context().get(), *pending_recv->frame->buffer, + &pending_recv->memh_p); + + ucp_request_param_t recv_param; + recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_MEMORY_TYPE | + UCP_OP_ATTR_FIELD_USER_DATA; + recv_param.cb.recv_am = AmRecvCallback; + recv_param.user_data = pending_recv; + recv_param.memory_type = InferMemoryType(*pending_recv->frame->buffer); + + void* dest = + reinterpret_cast(pending_recv->frame->buffer->mutable_address()); + void* request = + ucp_am_recv_data_nbx(worker_->get(), data, dest, data_length, &recv_param); + if (UCS_PTR_IS_ERR(request)) { + delete pending_recv; + return FromUcsStatus("ucp_am_recv_data_nbx", UCS_PTR_STATUS(request)); + } else if (!request) { + // Request completed instantly + Push(std::move(pending_recv->frame)); + delete pending_recv; + } + return UCS_OK; + } else { + // Data will be freed after callback returns - copy to buffer + if (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody) { + ARROW_ASSIGN_OR_RAISE(frame->buffer, + AllocateBuffer(data_length, read_memory_pool_)); + std::memcpy(frame->buffer->mutable_data(), data, data_length); + } else { + ARROW_ASSIGN_OR_RAISE( + frame->buffer, + MemoryManager::CopyNonOwned(Buffer(reinterpret_cast(data), + static_cast(data_length)), + memory_manager_)); + } + Push(std::move(frame)); + return UCS_OK; + } + } + + Status CompleteRequestBlocking(const std::string& context, void* request) { + if (UCS_PTR_IS_ERR(request)) { + return FromUcsStatus(context, UCS_PTR_STATUS(request)); + } else if (UCS_PTR_IS_PTR(request)) { + while (true) { + auto status = ucp_request_check_status(request); + if (status == UCS_OK) { + break; + } else if (status != UCS_INPROGRESS) { + ucp_request_release(request); + return FromUcsStatus("ucp_request_check_status", status); + } + MakeProgress(); + } + ucp_request_free(request); + } else { + // Send was completed instantly + DCHECK(!request); + } + return Status::OK(); + } + + Status CheckClosed() { + if (!endpoint_) { + return Status::Invalid("UcpCallDriver is closed"); + } + return Status::OK(); + } + + const std::array padding_bytes_; +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + ucp_mem_h padding_memh_p_; +#endif + + std::shared_ptr worker_; + ucp_ep_h endpoint_; + MemoryPool* read_memory_pool_; + MemoryPool* write_memory_pool_; + std::shared_ptr memory_manager_; + + // Internal name for logging/tracing + std::string name_; + // Counter used to reorder messages + uint32_t counter_ = 0; + + std::mutex frame_mutex_; + Status status_; + std::unordered_map>> frames_; + uint32_t next_counter_ = 0; +}; + +UcpCallDriver::UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint) + : impl_(new Impl(std::move(worker), endpoint)) {} +UcpCallDriver::UcpCallDriver(UcpCallDriver&&) = default; +UcpCallDriver& UcpCallDriver::operator=(UcpCallDriver&&) = default; +UcpCallDriver::~UcpCallDriver() = default; + +arrow::Result> UcpCallDriver::ReadNextFrame() { + return impl_->ReadNextFrame(); +} + +Future> UcpCallDriver::ReadFrameAsync() { + return impl_->ReadFrameAsync(); +} + +Status UcpCallDriver::ExpectFrameType(const Frame& frame, FrameType type) { + if (frame.type != type) { + return Status::IOError("Expected frame type ", static_cast(type), + ", but got frame type ", static_cast(frame.type)); + } + return Status::OK(); +} + +Status UcpCallDriver::StartCall(const std::string& method) { + std::vector> headers; + headers.emplace_back(kHeaderMethod, method); + ARROW_ASSIGN_OR_RAISE(auto frame, HeadersFrame::Make(headers)); + auto buffer = std::move(frame).GetBuffer(); + RETURN_NOT_OK(impl_->SendFrame(FrameType::kHeaders, buffer->data(), buffer->size())); + return Status::OK(); +} + +Future<> UcpCallDriver::SendFlightPayload(const FlightPayload& payload) { + return impl_->SendFlightPayload(payload); +} + +Status UcpCallDriver::SendFrame(FrameType frame_type, const uint8_t* data, + const int64_t size) { + return impl_->SendFrame(frame_type, data, size); +} + +Future<> UcpCallDriver::SendFrameAsync(FrameType frame_type, + std::unique_ptr buffer) { + return impl_->SendFrameAsync(frame_type, std::move(buffer)); +} + +Status UcpCallDriver::Close() { return impl_->Close(); } + +void UcpCallDriver::MakeProgress() { impl_->MakeProgress(); } + +ucs_status_t UcpCallDriver::RecvActiveMessage(const void* header, size_t header_length, + void* data, const size_t data_length, + const ucp_am_recv_param_t* param) { + return impl_->RecvActiveMessage(header, header_length, data, data_length, param); +} + +const std::shared_ptr& UcpCallDriver::memory_manager() const { + return impl_->memory_manager(); +} + +void UcpCallDriver::set_memory_manager(std::shared_ptr memory_manager) { + impl_->set_memory_manager(std::move(memory_manager)); +} +void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) { + impl_->set_read_memory_pool(pool); +} +void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) { + impl_->set_write_memory_pool(pool); +} +const std::string& UcpCallDriver::peer() const { return impl_->peer(); } + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h new file mode 100644 index 00000000000..bd176e23699 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h @@ -0,0 +1,354 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Common implementation of UCX communication primitives. + +#pragma once + +#include +#include +#include +#include + +#include + +#include "arrow/buffer.h" +#include "arrow/flight/server.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/flight/visibility.h" +#include "arrow/type_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +//------------------------------------------------------------ +// Protocol Constants + +static constexpr char kMethodDoExchange[] = "DoExchange"; +static constexpr char kMethodDoGet[] = "DoGet"; +static constexpr char kMethodDoPut[] = "DoPut"; +static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo"; + +static constexpr char kHeaderStatusCode[] = "flight-status-code"; +static constexpr char kHeaderStatusMessage[] = "flight-status-message"; +static constexpr char kHeaderStatusDetail[] = "flight-status-detail"; +static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code"; + +//------------------------------------------------------------ +// UCX Helpers + +/// \brief A wrapper around a ucp_context_h. +/// +/// Used so that multiple resources can share ownership of the +/// context. UCX has zero-copy optimizations where an application can +/// directly use a UCX buffer, but the lifetime of such buffers is +/// tied to the UCX context and worker, so ownership needs to be +/// preserved. +class UcpContext final { + public: + UcpContext() : ucp_context_(nullptr) {} + explicit UcpContext(ucp_context_h context) : ucp_context_(context) {} + ~UcpContext() { + if (ucp_context_) ucp_cleanup(ucp_context_); + ucp_context_ = nullptr; + } + ucp_context_h get() const { + DCHECK(ucp_context_); + return ucp_context_; + } + + private: + ucp_context_h ucp_context_; +}; + +/// \brief A wrapper around a ucp_worker_h. +class UcpWorker final { + public: + UcpWorker() : ucp_worker_(nullptr) {} + UcpWorker(std::shared_ptr context, ucp_worker_h worker) + : ucp_context_(std::move(context)), ucp_worker_(worker) {} + ~UcpWorker() { + if (ucp_worker_) ucp_worker_destroy(ucp_worker_); + ucp_worker_ = nullptr; + } + ucp_worker_h get() const { + DCHECK(ucp_worker_); + return ucp_worker_; + } + const UcpContext& context() const { return *ucp_context_; } + + private: + std::shared_ptr ucp_context_; + ucp_worker_h ucp_worker_; +}; + +//------------------------------------------------------------ +// Message Framing + +/// \brief The message type. +enum class FrameType : uint8_t { + /// Key-value headers. Sent at the beginning (client->server) and + /// end (server->client) of a call. Also, for client-streaming calls + /// (e.g. DoPut), the client should send a headers frame to signal + /// end-of-stream. + kHeaders = 0, + /// Binary blob, does not contain Arrow data. + kBuffer, + /// Binary blob. Contains IPC metadata, app metadata. + kPayloadHeader, + /// Binary blob. Contains IPC body. Body is sent separately since it + /// may use a different memory type. + kPayloadBody, + /// Ask server to disconnect (to avoid client/server waiting on each + /// other and getting stuck). + kDisconnect, + /// Keep at end. + kMaxFrameType = kDisconnect, +}; + +/// \brief The header of a message frame. Used when sending only. +/// +/// A frame is expected to be sent over UCP Active Messages and +/// consists of a header (of kFrameHeaderBytes bytes) and a body. +/// +/// The header is as follows: +/// +-------+---------------------------------+ +/// | Bytes | Function | +/// +=======+=================================+ +/// | 0 | Version tag (see kFrameVersion) | +/// | 1 | Frame type (see FrameType) | +/// | 2-3 | Unused, reserved | +/// | 4-7 | Frame counter (big-endian) | +/// | 8-11 | Body size (big-endian) | +/// +-------+---------------------------------+ +/// +/// The frame counter lets the receiver ensure messages are processed +/// in-order. (The message receive callback may use +/// ucp_am_recv_data_nbx which is asynchronous.) +/// +/// The body size reports the expected message size (UCX chokes on +/// zero-size payloads which we occasionally want to send, so the size +/// field in the header lets us know when a payload was meant to be +/// empty). +struct FrameHeader { + /// \brief The size of a frame header. + static constexpr size_t kFrameHeaderBytes = 12; + /// \brief The expected version tag in the header. + static constexpr uint8_t kFrameVersion = 0x01; + + FrameHeader() = default; + /// \brief Initialize the frame header. + Status Set(FrameType frame_type, uint32_t counter, int64_t body_size); + void* data() const { return header.data(); } + size_t size() const { return kFrameHeaderBytes; } + + // mutable since UCX expects void* not const void* + mutable std::array header = {0}; +}; + +/// \brief A single message received via UCX. Used when receiving only. +struct Frame { + /// \brief The message type. + FrameType type; + /// \brief The message length. + uint32_t size; + /// \brief An incrementing message counter (may wrap over). + uint32_t counter; + /// \brief The message contents. + std::unique_ptr buffer; + + Frame() = default; + Frame(FrameType type_, uint32_t size_, uint32_t counter_, + std::unique_ptr buffer_) + : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {} + + util::string_view view() const { + return util::string_view(reinterpret_cast(buffer->data()), size); + } + + /// \brief Parse a UCX active message header. This will not + /// initialize the buffer field. + static arrow::Result> ParseHeader(const void* header, + size_t header_length); +}; + +/// \brief The active message handler callback ID. +static constexpr uint32_t kUcpAmHandlerId = 0x1024; + +/// \brief A collection of key-value headers. +/// +/// This should be stored in a frame of type kHeaders. +/// +/// Format: +/// +-------+----------------------------------+ +/// | Bytes | Contents | +/// +=======+==================================+ +/// | 0-4 | # of headers (big-endian) | +/// | 4-8 | Header key length (big-endian) | +/// | 2-3 | Header value length (big-endian) | +/// | (...) | Header key | +/// | (...) | Header value | +/// | (...) | (repeat from row 2) | +/// +-------+----------------------------------+ +class HeadersFrame { + public: + /// \brief Get a header value (or an error if it was not found) + arrow::Result Get(const std::string& key); + /// \brief Extract the server-sent status. + Status GetStatus(Status* out); + /// \brief Parse the headers from the buffer. + static arrow::Result Parse(std::unique_ptr buffer); + /// \brief Create a new frame with the given headers. + static arrow::Result Make( + const std::vector>& headers); + /// \brief Create a new frame with the given headers and the given status. + static arrow::Result Make( + const Status& status, + const std::vector>& headers); + + /// \brief Take ownership of the underlying buffer. + std::unique_ptr GetBuffer() && { return std::move(buffer_); } + + private: + std::unique_ptr buffer_; + std::vector> headers_; +}; + +/// \brief A representation of a kPayloadHeader frame (i.e. all of the +/// metadata in a FlightPayload/FlightData). +/// +/// Data messages are sent in two parts: one containing all metadata +/// (the Flatbuffers header, FlightDescriptor, and app_metadata +/// fields) and one containing the actual data. This was done to avoid +/// having to concatenate these fields with the data itself (in the +/// cases where we are not using IOV). +/// +/// Format: +/// +--------+----------------------------------+ +/// | Bytes | Contents | +/// +========+==================================+ +/// | 0-4 | Descriptor length (big-endian) | +/// | 4..a | Descriptor bytes | +/// | a-a+4 | app_metadata length (big-endian) | +/// | a+4..b | app_metadata bytes | +/// | b-b+4 | ipc_metadata length (big-endian) | +/// | b+4..c | ipc_metadata bytes | +/// +--------+----------------------------------+ +/// +/// If a field is not present, its length is still there, but is set +/// to UINT32_MAX. +class PayloadHeaderFrame { + public: + explicit PayloadHeaderFrame(std::unique_ptr buffer) + : buffer_(std::move(buffer)) {} + /// \brief Unpack the internal buffer into a FlightData. + Status ToFlightData(internal::FlightData* data); + /// \brief Pack a payload into the internal buffer. + static arrow::Result Make(const FlightPayload& payload, + MemoryPool* memory_pool); + const uint8_t* data() const { return buffer_->data(); } + int64_t size() const { return buffer_->size(); } + + private: + std::unique_ptr buffer_; +}; + +/// \brief Manage the state of a UCX connection. +class UcpCallDriver { + public: + UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint); + + UcpCallDriver(const UcpCallDriver&) = delete; + UcpCallDriver(UcpCallDriver&&); + void operator=(const UcpCallDriver&) = delete; + UcpCallDriver& operator=(UcpCallDriver&&); + + ~UcpCallDriver(); + + /// \brief Start a call by sending a headers frame. Client side only. + /// + /// \param[in] method The RPC method. + Status StartCall(const std::string& method); + + /// \brief Synchronously send a generic message with binary payload. + Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size); + /// \brief Asynchronously send a generic message with binary payload. + /// + /// The UCP driver must be manually polled (call MakeProgress()). + Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer); + /// \brief Asynchronously send a data message. + /// + /// The UCP driver must be manually polled (call MakeProgress()). + Future<> SendFlightPayload(const FlightPayload& payload); + + /// \brief Synchronously read the next frame. + arrow::Result> ReadNextFrame(); + /// \brief Asynchronously read the next frame. + /// + /// The UCP driver must be manually polled (call MakeProgress()). + Future> ReadFrameAsync(); + + /// \brief Validate that the frame is of the given type. + Status ExpectFrameType(const Frame& frame, FrameType type); + + /// \brief Disconnect the other side of the connection. Note, this + /// can cause deadlock. + Status Close(); + + /// \brief Synchronously make progress (to adapt async to sync APIs) + void MakeProgress(); + + /// \brief Get the associated memory manager. + const std::shared_ptr& memory_manager() const; + /// \brief Set the associated memory manager. + void set_memory_manager(std::shared_ptr memory_manager); + /// \brief Set memory pool for scratch space used during reading. + void set_read_memory_pool(MemoryPool* memory_pool); + /// \brief Set memory pool for scratch space used during writing. + void set_write_memory_pool(MemoryPool* memory_pool); + /// \brief Get a debug string naming the peer. + const std::string& peer() const; + + /// \brief Process an incoming active message. This will unblock the + /// corresponding call to ReadFrameAsync/ReadNextFrame. + ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data, + const size_t data_length, + const ucp_am_recv_param_t* param); + + private: + class Impl; + std::unique_ptr impl_; +}; + +ARROW_FLIGHT_EXPORT +std::unique_ptr MakeUcxClientImpl(); + +ARROW_FLIGHT_EXPORT +std::unique_ptr MakeUcxServerImpl( + FlightServerBase* base, std::shared_ptr memory_manager); + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc new file mode 100644 index 00000000000..74a9311d0c8 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -0,0 +1,628 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/server.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/flight/transport_server.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +// Send an error to the client and return OK. +// Statuses returned up to the main server loop trigger a kReset instead. +#define SERVER_RETURN_NOT_OK(driver, status) \ + do { \ + ::arrow::Status s = (status); \ + if (!s.ok()) { \ + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(s, {})); \ + auto payload = std::move(headers).GetBuffer(); \ + RETURN_NOT_OK( \ + driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); \ + return ::arrow::Status::OK(); \ + } \ + } while (false) + +#define FLIGHT_LOG(LEVEL) (ARROW_LOG(LEVEL) << "[server] ") +#define FLIGHT_LOG_PEER(LEVEL, PEER) \ + (ARROW_LOG(LEVEL) << "[server]" \ + << "[peer=" << (PEER) << "] ") + +namespace { +class UcxServerCallContext : public flight::ServerCallContext { + public: + const std::string& peer_identity() const override { return peer_; } + const std::string& peer() const override { return peer_; } + ServerMiddleware* GetMiddleware(const std::string& key) const override { + return nullptr; + } + bool is_cancelled() const override { return false; } + + private: + std::string peer_; +}; + +class UcxServerStream : public internal::ServerDataStream { + public: + explicit UcxServerStream(UcpCallDriver* driver) + : peer_(driver->peer()), driver_(driver), writes_done_(false) {} + + Status WritesDone() override { + writes_done_ = true; + return Status::OK(); + } + + protected: + std::string peer_; + UcpCallDriver* driver_; + bool writes_done_; +}; + +class GetServerStream : public UcxServerStream { + public: + using UcxServerStream::UcxServerStream; + + arrow::Result WriteData(const FlightPayload& payload) override { + if (writes_done_) return false; + Future<> pending_send = driver_->SendFlightPayload(payload); + while (!pending_send.is_finished()) { + driver_->MakeProgress(); + } + RETURN_NOT_OK(pending_send.status()); + return true; + } +}; + +class PutServerStream : public UcxServerStream { + public: + explicit PutServerStream(UcpCallDriver* driver) + : UcxServerStream(driver), finished_(false) {} + + bool ReadData(internal::FlightData* data) override { + if (finished_) return false; + + bool success = true; + auto status = ReadImpl(data).Value(&success); + + if (!status.ok() || !success) { + finished_ = true; + if (!status.ok()) { + FLIGHT_LOG_PEER(WARNING, peer_) << "I/O error in DoPut: " << status.ToString(); + return false; + } + } + return success; + } + + Status WritePutMetadata(const Buffer& payload) override { + if (finished_) return Status::OK(); + // Send synchronously (we don't control payload lifetime) + return driver_->SendFrame(FrameType::kBuffer, payload.data(), payload.size()); + } + + private: + ::arrow::Result ReadImpl(internal::FlightData* data) { + ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame()); + if (frame->type == FrameType::kHeaders) { + // Trailers, client is done writing + return false; + } + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader)); + PayloadHeaderFrame payload_header(std::move(frame->buffer)); + RETURN_NOT_OK(payload_header.ToFlightData(data)); + + if (data->metadata) { + ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr)); + + if (ipc::Message::HasBody(message->type())) { + ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame()); + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody)); + data->body = std::move(frame->buffer); + } + } + return true; + } + + bool finished_; +}; + +class ExchangeServerStream : public PutServerStream { + public: + using PutServerStream::PutServerStream; + + arrow::Result WriteData(const FlightPayload& payload) override { + if (writes_done_) return false; + Future<> pending_send = driver_->SendFlightPayload(payload); + while (!pending_send.is_finished()) { + driver_->MakeProgress(); + } + RETURN_NOT_OK(pending_send.status()); + return true; + } + Status WritePutMetadata(const Buffer& payload) override { + return Status::NotImplemented("Not supported on this stream"); + } +}; + +class UcxServerImpl : public arrow::flight::internal::ServerTransport { + public: + using arrow::flight::internal::ServerTransport::ServerTransport; + + virtual ~UcxServerImpl() { + if (listening_.load()) { + auto st = Shutdown(); + if (!st.ok()) { + ARROW_LOG(WARNING) << "Server did not shut down properly: " << st.ToString(); + } + } + } + + Status Init(const FlightServerOptions& options, + const arrow::internal::Uri& uri) override { + const auto max_threads = std::max(8, std::thread::hardware_concurrency()); + ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(max_threads)); + + struct sockaddr_storage listen_addr; + ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); + + // Init UCX + { + ucp_config_t* ucp_config; + ucp_params_t ucp_params; + ucs_status_t status; + + status = ucp_config_read(nullptr, nullptr, &ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); + + // If location is IPv6, must adjust UCX config + if (listen_addr.ss_family == AF_INET6) { + status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); + RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status)); + } + + // Allow application to override UCP config + if (options.builder_hook) options.builder_hook(ucp_config); + + std::memset(&ucp_params, 0, sizeof(ucp_params)); + ucp_params.field_mask = + UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED; + ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; + ucp_params.mt_workers_shared = UCS_THREAD_MODE_MULTI; + + ucp_context_h ucp_context; + status = ucp_init(&ucp_params, ucp_config, &ucp_context); + ucp_config_release(ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_init", status)); + ucp_context_.reset(new UcpContext(ucp_context)); + } + + { + // Create one worker to listen for incoming connections. + ucp_worker_params_t worker_params; + ucs_status_t status; + + std::memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_MULTI; + ucp_worker_h worker; + status = ucp_worker_create(ucp_context_->get(), &worker_params, &worker); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); + worker_conn_.reset(new UcpWorker(ucp_context_, worker)); + } + + // Start listening for connections. + { + ucp_listener_params_t params; + ucs_status_t status; + + params.field_mask = + UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; + params.sockaddr.addr = reinterpret_cast(&listen_addr); + params.sockaddr.addrlen = addrlen; + params.conn_handler.cb = HandleIncomingConnection; + params.conn_handler.arg = this; + + status = ucp_listener_create(worker_conn_->get(), ¶ms, &listener_); + RETURN_NOT_OK(FromUcsStatus("ucp_listener_create", status)); + + // Get the real address/port + ucp_listener_attr_t attr; + attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR; + status = ucp_listener_query(listener_, &attr); + RETURN_NOT_OK(FromUcsStatus("ucp_listener_query", status)); + + std::string raw_uri = "ucx://"; + if (uri.host().find(':') != std::string::npos) { + // IPv6 host + raw_uri += '['; + raw_uri += uri.host(); + raw_uri += ']'; + } else { + raw_uri += uri.host(); + } + raw_uri += ":"; + raw_uri += std::to_string( + ntohs(reinterpret_cast(&attr.sockaddr)->sin_port)); + std::string listen_str; + ARROW_UNUSED(SockaddrToString(attr.sockaddr).Value(&listen_str)); + FLIGHT_LOG(DEBUG) << "Listening on " << listen_str; + ARROW_ASSIGN_OR_RAISE(location_, Location::Parse(raw_uri)); + } + + { + listening_.store(true); + std::thread listener_thread(&UcxServerImpl::DriveConnections, this); + listener_thread_.swap(listener_thread); + } + + return Status::OK(); + } + + Status Shutdown() override { + if (!listening_.load()) return Status::OK(); + Status status; + + // Wait for current RPCs to finish + listening_.store(false); + // Unstick the listener thread from ucp_worker_wait + RETURN_NOT_OK( + FromUcsStatus("ucp_worker_signal", ucp_worker_signal(worker_conn_->get()))); + status &= Wait(); + + { + // Reject all pending connections + std::unique_lock guard(pending_connections_mutex_); + while (!pending_connections_.empty()) { + status &= + FromUcsStatus("ucp_listener_reject", + ucp_listener_reject(listener_, pending_connections_.front())); + pending_connections_.pop(); + } + ucp_listener_destroy(listener_); + worker_conn_.reset(); + } + + status &= rpc_pool_->Shutdown(); + rpc_pool_.reset(); + + ucp_context_.reset(); + return status; + } + + Status Shutdown(const std::chrono::system_clock::time_point& deadline) override { + // TODO(ARROW-16125): implement shutdown with deadline + return Shutdown(); + } + + Status Wait() override { + std::lock_guard guard(join_mutex_); + try { + listener_thread_.join(); + } catch (const std::system_error& e) { + if (e.code() != std::errc::invalid_argument) { + return Status::UnknownError("Could not Wait(): ", e.what()); + } + // Else, server wasn't running anyways + } + return Status::OK(); + } + + Location location() const override { return location_; } + + private: + struct ClientWorker { + std::shared_ptr worker; + std::unique_ptr driver; + }; + + Status SendStatus(UcpCallDriver* driver, const Status& status) { + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(status, {})); + auto payload = std::move(headers).GetBuffer(); + RETURN_NOT_OK( + driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); + return Status::OK(); + } + + Status HandleGetFlightInfo(UcpCallDriver* driver) { + UcxServerCallContext context; + + ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); + SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); + FlightDescriptor descriptor; + SERVER_RETURN_NOT_OK(driver, + FlightDescriptor::Deserialize(util::string_view(*frame->buffer)) + .Value(&descriptor)); + + std::unique_ptr info; + std::string response; + SERVER_RETURN_NOT_OK(driver, base_->GetFlightInfo(context, descriptor, &info)); + SERVER_RETURN_NOT_OK(driver, info->SerializeToString().Value(&response)); + RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, + reinterpret_cast(response.data()), + static_cast(response.size()))); + RETURN_NOT_OK(SendStatus(driver, Status::OK())); + return Status::OK(); + } + + Status HandleDoGet(UcpCallDriver* driver) { + UcxServerCallContext context; + + ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); + SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); + Ticket ticket; + SERVER_RETURN_NOT_OK(driver, Ticket::Deserialize(frame->view()).Value(&ticket)); + + GetServerStream stream(driver); + auto status = DoGet(context, std::move(ticket), &stream); + RETURN_NOT_OK(SendStatus(driver, status)); + return Status::OK(); + } + + Status HandleDoPut(UcpCallDriver* driver) { + UcxServerCallContext context; + + PutServerStream stream(driver); + auto status = DoPut(context, &stream); + RETURN_NOT_OK(SendStatus(driver, status)); + // Must drain any unread messages, or the next call will get confused + internal::FlightData ignored; + while (stream.ReadData(&ignored)) { + } + return Status::OK(); + } + + Status HandleDoExchange(UcpCallDriver* driver) { + UcxServerCallContext context; + + ExchangeServerStream stream(driver); + auto status = DoExchange(context, &stream); + RETURN_NOT_OK(SendStatus(driver, status)); + // Must drain any unread messages, or the next call will get confused + internal::FlightData ignored; + while (stream.ReadData(&ignored)) { + } + return Status::OK(); + } + + Status HandleOneCall(UcpCallDriver* driver, Frame* frame) { + SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kHeaders)); + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); + ARROW_ASSIGN_OR_RAISE(auto method, headers.Get(":method:")); + if (method == kMethodGetFlightInfo) { + return HandleGetFlightInfo(driver); + } else if (method == kMethodDoExchange) { + return HandleDoExchange(driver); + } else if (method == kMethodDoGet) { + return HandleDoGet(driver); + } else if (method == kMethodDoPut) { + return HandleDoPut(driver); + } + RETURN_NOT_OK(SendStatus(driver, Status::NotImplemented(method))); + return Status::OK(); + } + + void WorkerLoop(ucp_conn_request_h request) { + std::string peer = "unknown:" + std::to_string(counter_++); + { + ucp_conn_request_attr_t request_attr; + std::memset(&request_attr, 0, sizeof(request_attr)); + request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR; + if (ucp_conn_request_query(request, &request_attr) == UCS_OK) { + ARROW_UNUSED(SockaddrToString(request_attr.client_address).Value(&peer)); + } + } + FLIGHT_LOG_PEER(DEBUG, peer) << "Received connection request"; + + auto maybe_worker = CreateWorker(); + if (!maybe_worker.ok()) { + FLIGHT_LOG_PEER(WARNING, peer) + << "Failed to create worker" << maybe_worker.status().ToString(); + auto status = ucp_listener_reject(listener_, request); + if (status != UCS_OK) { + FLIGHT_LOG_PEER(WARNING, peer) + << FromUcsStatus("ucp_listener_reject", status).ToString(); + } + return; + } + auto worker = maybe_worker.MoveValueUnsafe(); + + // Create an endpoint to the client, using the data worker + { + ucs_status_t status; + ucp_ep_params_t params; + std::memset(¶ms, 0, sizeof(params)); + params.field_mask = UCP_EP_PARAM_FIELD_CONN_REQUEST; + params.conn_request = request; + + ucp_ep_h client_endpoint; + + status = ucp_ep_create(worker->worker->get(), ¶ms, &client_endpoint); + if (status != UCS_OK) { + FLIGHT_LOG_PEER(WARNING, peer) + << "Failed to create endpoint: " + << FromUcsStatus("ucp_ep_create", status).ToString(); + return; + } + worker->driver.reset(new UcpCallDriver(worker->worker, client_endpoint)); + worker->driver->set_memory_manager(memory_manager_); + peer = worker->driver->peer(); + } + + while (listening_.load()) { + auto maybe_frame = worker->driver->ReadNextFrame(); + if (!maybe_frame.ok()) { + if (!maybe_frame.status().IsCancelled()) { + FLIGHT_LOG_PEER(WARNING, peer) + << "Failed to read next message: " << maybe_frame.status().ToString(); + } + break; + } + + auto status = HandleOneCall(worker->driver.get(), maybe_frame->get()); + if (!status.ok()) { + FLIGHT_LOG_PEER(WARNING, peer) << "Call failed: " << status.ToString(); + break; + } + } + + // Clean up + auto status = worker->driver->Close(); + if (!status.ok()) { + FLIGHT_LOG_PEER(WARNING, peer) << "Failed to close worker: " << status.ToString(); + } + worker->worker.reset(); + FLIGHT_LOG_PEER(DEBUG, peer) << "Disconnected"; + } + + void DriveConnections() { + while (listening_.load()) { + while (ucp_worker_progress(worker_conn_->get())) { + } + { + // Check for connect requests in queue + std::unique_lock guard(pending_connections_mutex_); + while (!pending_connections_.empty()) { + ucp_conn_request_h request = pending_connections_.front(); + pending_connections_.pop(); + + auto submitted = rpc_pool_->Submit([this, request]() { WorkerLoop(request); }); + if (!submitted.ok()) { + ARROW_LOG(WARNING) << "Failed to submit task to handle client " + << submitted.status().ToString(); + } + } + } + + // Check listening_ in case we're shutting down. It is possible + // that Shutdown() was called while we were in + // ucp_worker_progress above, in which case if we don't check + // listening_ here, we'll enter ucp_worker_wait and get stuck. + if (!listening_.load()) break; + auto status = ucp_worker_wait(worker_conn_->get()); + if (status != UCS_OK) { + FLIGHT_LOG(WARNING) << FromUcsStatus("ucp_worker_wait", status).ToString(); + } + } + } + + void EnqueueClient(ucp_conn_request_h connection_request) { + std::unique_lock guard(pending_connections_mutex_); + pending_connections_.push(connection_request); + guard.unlock(); + } + + arrow::Result> CreateWorker() { + auto worker = std::make_shared(); + + ucp_worker_params_t worker_params; + std::memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; + + ucp_worker_h ucp_worker; + auto status = ucp_worker_create(ucp_context_->get(), &worker_params, &ucp_worker); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); + worker->worker.reset(new UcpWorker(ucp_context_, ucp_worker)); + + // Set up Active Message (AM) handler + ucp_am_handler_param_t handler_params; + std::memset(&handler_params, 0, sizeof(handler_params)); + handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | + UCP_AM_HANDLER_PARAM_FIELD_CB | + UCP_AM_HANDLER_PARAM_FIELD_ARG; + handler_params.id = kUcpAmHandlerId; + handler_params.cb = HandleIncomingActiveMessage; + handler_params.arg = worker.get(); + + status = ucp_worker_set_am_recv_handler(worker->worker->get(), &handler_params); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status)); + return worker; + } + + // Callback handler. A new client has connected to the server. + static void HandleIncomingConnection(ucp_conn_request_h connection_request, + void* data) { + UcxServerImpl* server = reinterpret_cast(data); + // TODO(ARROW-16124): enable shedding load above some threshold + // (which is a pitfall with gRPC/Java) + server->EnqueueClient(connection_request); + } + + static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header, + size_t header_length, void* data, + size_t data_length, + const ucp_am_recv_param_t* param) { + ClientWorker* worker = reinterpret_cast(self); + DCHECK(worker->driver); + return worker->driver->RecvActiveMessage(header, header_length, data, data_length, + param); + } + + std::shared_ptr ucp_context_; + // Listen for and handle incoming connections + std::shared_ptr worker_conn_; + ucp_listener_h listener_; + Location location_; + + // Counter for identifying peers when UCX doesn't give us a way + std::atomic counter_; + + std::shared_ptr rpc_pool_; + std::atomic listening_; + std::thread listener_thread_; + // std::thread::join cannot be called concurrently + std::mutex join_mutex_; + + std::mutex pending_connections_mutex_; + std::queue pending_connections_; +}; +} // namespace + +std::unique_ptr MakeUcxServerImpl( + FlightServerBase* base, std::shared_ptr memory_manager) { + return arrow::internal::make_unique(base, memory_manager); +} + +#undef SERVER_RETURN_NOT_OK +#undef FLIGHT_LOG +#undef FLIGHT_LOG_PEER + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc new file mode 100644 index 00000000000..ca4df21a055 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -0,0 +1,289 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/util_internal.h" + +#include +#include +#include + +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/types.h" +#include "arrow/util/base64.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +constexpr char FlightUcxStatusDetail::kTypeId[]; +std::string FlightUcxStatusDetail::ToString() const { return ucs_status_string(status_); } +ucs_status_t FlightUcxStatusDetail::Unwrap(const Status& status) { + if (!status.detail() || status.detail()->type_id() != kTypeId) return UCS_OK; + return dynamic_cast(status.detail().get())->status_; +} + +arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, + struct sockaddr_storage* addr) { + std::string host = uri.host(); + if (host.empty()) { + return Status::Invalid("Must provide a host"); + } else if (uri.port() < 0) { + return Status::Invalid("Must provide a port"); + } + + std::memset(addr, 0, sizeof(*addr)); + + struct addrinfo* info = nullptr; + int err = getaddrinfo(host.c_str(), /*service=*/nullptr, /*hints=*/nullptr, &info); + if (err != 0) { + if (err == EAI_SYSTEM) { + return arrow::internal::IOErrorFromErrno(errno, "[getaddrinfo] Failure resolving ", + host); + } else { + return Status::IOError("[getaddrinfo] Failure resolving ", host, ": ", + gai_strerror(err)); + } + } + + struct addrinfo* cur_info = info; + while (cur_info) { + if (cur_info->ai_family != AF_INET && cur_info->ai_family != AF_INET6) { + cur_info = cur_info->ai_next; + continue; + } + + std::memcpy(addr, cur_info->ai_addr, cur_info->ai_addrlen); + if (cur_info->ai_family == AF_INET) { + reinterpret_cast(addr)->sin_port = htons(uri.port()); + } else if (cur_info->ai_family == AF_INET6) { + reinterpret_cast(addr)->sin6_port = htons(uri.port()); + } + size_t addrlen = cur_info->ai_addrlen; + freeaddrinfo(info); + return addrlen; + } + + if (info) freeaddrinfo(info); + return Status::IOError("[getaddrinfo] Failure resolving ", host, + ": no results of a supported family returned"); +} + +arrow::Result SockaddrToString(const struct sockaddr_storage& address) { + std::string result = ""; + if (address.ss_family != AF_INET && address.ss_family != AF_INET6) { + return Status::NotImplemented("Unknown address family"); + } + + uint16_t port = 0; + if (address.ss_family == AF_INET) { + result.resize(INET_ADDRSTRLEN + 1); + const auto* in_addr = reinterpret_cast(&address); + if (!inet_ntop(address.ss_family, &in_addr->sin_addr, &result[0], INET_ADDRSTRLEN)) { + return arrow::internal::IOErrorFromErrno(errno, + "Could not convert address to string"); + } + port = ntohs(in_addr->sin_port); + } else { + result.resize(INET6_ADDRSTRLEN + 1); + const auto* in6_addr = reinterpret_cast(&address); + if (!inet_ntop(address.ss_family, &in6_addr->sin6_addr, &result[0], + INET6_ADDRSTRLEN)) { + return arrow::internal::IOErrorFromErrno(errno, + "Could not convert address to string"); + } + port = ntohs(in6_addr->sin6_port); + } + + const size_t pos = result.find('\0'); + DCHECK_NE(pos, std::string::npos); + result[pos] = ':'; + result.resize(pos + 1); + result += std::to_string(port); + return result; +} + +Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status) { + switch (ucs_status) { + case UCS_OK: + return Status::OK(); + case UCS_INPROGRESS: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_INPROGRESS ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NO_MESSAGE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_MESSAGE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NO_RESOURCE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_RESOURCE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_IO_ERROR: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_IO_ERROR ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NO_MEMORY: + return Status::OutOfMemory(context, ": UCX error ", + static_cast(ucs_status), ": ", + "UCS_ERR_NO_MEMORY ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_INVALID_PARAM: + return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_INVALID_PARAM ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_UNREACHABLE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_UNREACHABLE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_INVALID_ADDR: + return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_INVALID_ADDR ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NOT_IMPLEMENTED: + return Status::NotImplemented( + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_NOT_IMPLEMENTED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_MESSAGE_TRUNCATED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_MESSAGE_TRUNCATED ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NO_PROGRESS: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_PROGRESS ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_BUFFER_TOO_SMALL: + return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_BUFFER_TOO_SMALL ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NO_ELEM: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_ELEM ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_SOME_CONNECTS_FAILED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_SOME_CONNECTS_FAILED ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NO_DEVICE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_DEVICE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_BUSY: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_BUSY ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_CANCELED: + return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_CANCELED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_SHMEM_SEGMENT: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_SHMEM_SEGMENT ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_ALREADY_EXISTS: + return Status::AlreadyExists( + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_ALREADY_EXISTS ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_OUT_OF_RANGE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_OUT_OF_RANGE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_TIMED_OUT: + return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_TIMED_OUT ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_EXCEEDS_LIMIT: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_EXCEEDS_LIMIT ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_UNSUPPORTED: + return Status::NotImplemented(context, ": UCX error ", + static_cast(ucs_status), ": ", + "UCS_ERR_UNSUPPORTED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_REJECTED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_REJECTED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_NOT_CONNECTED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NOT_CONNECTED ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_CONNECTION_RESET: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_CONNECTION_RESET ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_FIRST_LINK_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_FIRST_LINK_FAILURE ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_LAST_LINK_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_LAST_LINK_FAILURE ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_FIRST_ENDPOINT_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_FIRST_ENDPOINT_FAILURE ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_LAST_ENDPOINT_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_LAST_ENDPOINT_FAILURE ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_ENDPOINT_TIMEOUT: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_ENDPOINT_TIMEOUT ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + case UCS_ERR_LAST: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_LAST ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + default: + return Status::UnknownError( + context, ": Unknown UCX error: ", static_cast(ucs_status), " ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); + } +} + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h new file mode 100644 index 00000000000..84e84ba0711 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/flight/visibility.h" +#include "arrow/status.h" +#include "arrow/util/endian.h" +#include "arrow/util/ubsan.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +static inline void UInt32ToBytesBe(const uint32_t in, uint8_t* out) { + util::SafeStore(out, bit_util::ToBigEndian(in)); +} + +static inline uint32_t BytesToUInt32Be(const uint8_t* in) { + return bit_util::FromBigEndian(util::SafeLoadAs(in)); +} + +class ARROW_FLIGHT_EXPORT FlightUcxStatusDetail : public StatusDetail { + public: + explicit FlightUcxStatusDetail(ucs_status_t status) : status_(status) {} + static constexpr char const kTypeId[] = "flight::transport::ucx::FlightUcxStatusDetail"; + + const char* type_id() const override { return kTypeId; } + std::string ToString() const override; + static ucs_status_t Unwrap(const Status& status); + + private: + ucs_status_t status_; +}; + +/// \brief Convert a UCS status to an Arrow Status. +ARROW_FLIGHT_EXPORT +Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status); + +/// \brief Check if a UCS error code can be ignored in the context of +/// a disconnect. +static inline bool IsIgnorableDisconnectError(ucs_status_t ucs_status) { + // Not connected, connection reset: we're already disconnected + // Timeout: most likely disconnected, but we can't tell from our end + return ucs_status == UCS_OK || ucs_status == UCS_ERR_ENDPOINT_TIMEOUT || + ucs_status == UCS_ERR_NOT_CONNECTED || ucs_status == UCS_ERR_CONNECTION_RESET; +} + +/// \brief Helper to convert a Uri to a struct sockaddr (used in +/// ucp_listener_params_t) +/// +/// \return The length of the sockaddr +ARROW_FLIGHT_EXPORT +arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, + struct sockaddr_storage* addr); + +ARROW_FLIGHT_EXPORT +arrow::Result SockaddrToString(const struct sockaddr_storage& address); + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport_server.cc b/cpp/src/arrow/flight/transport_server.cc index fa5bf827100..4944a79b8fb 100644 --- a/cpp/src/arrow/flight/transport_server.cc +++ b/cpp/src/arrow/flight/transport_server.cc @@ -54,8 +54,7 @@ class TransportIpcMessageReader : public ipc::MessageReader { stream_finished_ = true; return nullptr; } - if (data->body && - ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) { + if (data->body) { ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_)); } *app_metadata_ = std::move(data->app_metadata); @@ -111,7 +110,7 @@ class TransportMessageReader final : public FlightMessageReader { arrow::Result Next() override { FlightStreamChunk out; - internal::FlightData* data; + internal::FlightData* data = nullptr; peekable_reader_->Peek(&data); if (!data) { out.app_metadata = nullptr; diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake index 7d7c83185ef..55bc2d01005 100644 --- a/cpp/src/arrow/util/config.h.cmake +++ b/cpp/src/arrow/util/config.h.cmake @@ -48,5 +48,6 @@ #cmakedefine ARROW_S3 #cmakedefine ARROW_USE_NATIVE_INT128 #cmakedefine ARROW_WITH_OPENTELEMETRY +#cmakedefine ARROW_WITH_UCX #cmakedefine GRPCPP_PP_INCLUDE diff --git a/docs/source/cpp/flight.rst b/docs/source/cpp/flight.rst index c1d2e43b9f4..75aea3c47c1 100644 --- a/docs/source/cpp/flight.rst +++ b/docs/source/cpp/flight.rst @@ -117,3 +117,38 @@ success/failure of the request. Any other return values are specified through out parameters. They also take an optional :class:`options ` parameter that allows specifying a timeout for the call. + +Alternative Transports +====================== + +The standard transport for Arrow Flight is gRPC_. The C++ +implementation also experimentally supports a transport based on +UCX_. To use it, use the protocol scheme ``ucx:`` when starting a +server or creating a client. + +UCX Transport +------------- + +Not all features of the gRPC transport are supported. See +:ref:`status-flight-rpc` for details. Also note these specific +caveats: + +- The server creates an independent UCP worker for each client. This + consumes more resources but provides better throughput. +- The client creates an independent UCP worker for each RPC + call. Again, this trades off resource consumption for + performance. This also means that unlike with gRPC, it is + essentially equivalent to make all calls with a single client or + with multiple clients. +- The UCX transport attempts to avoid copies where possible. In some + cases, it can directly reuse UCX-allocated buffers to back + :class:`arrow::Buffer` objects, however, this will also extend the + lifetime of associated UCX resources beyond the lifetime of the + Flight client or server object. +- Depending on the transport that UCX itself selects, you may find + that increasing ``UCX_MM_SEG_SIZE`` from the default (around 8KB) to + around 60KB improves performance (UCX will copy more data in a + single call). + +.. _gRPC: https://grpc.io/ +.. _UCX: https://openucx.org/ diff --git a/docs/source/status.rst b/docs/source/status.rst index 7c6157357a8..c30caed2f84 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -144,35 +144,81 @@ Notes: .. seealso:: The :ref:`format-ipc` specification. +.. _status-flight-rpc: Flight RPC ========== -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| | | | | | | | | -+=============================+=======+=======+=======+============+=======+=======+=======+ -| gRPC transport | ✓ | ✓ | ✓ | | ✓ (1) | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| gRPC + TLS transport | ✓ | ✓ | ✓ | | ✓ | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| RPC error codes | ✓ | ✓ | ✓ | | ✓ | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom client middleware | ✓ | ✓ | ✓ | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom server middleware | ✓ | ✓ | ✓ | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ +.. note:: Flight RPC is still experimental. + ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Flight RPC Transport | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| gRPC_ transport (grpc:, grpc+tcp:) | ✓ | ✓ | ✓ | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| gRPC domain socket transport (grpc+unix:) | ✓ | ✓ | ✓ | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| gRPC + TLS transport (grpc+tls:) | ✓ | ✓ | ✓ | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| UCX_ transport (ucx:) | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ + +Supported features in the gRPC transport: + ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| All RPC methods | ✓ | ✓ | ✓ | | × (1) | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call timeouts | ✓ | ✓ | ✓ | | | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call cancellation | ✓ | ✓ | ✓ | | | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Concurrent client calls (3) | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Custom middleware | ✓ | ✓ | ✓ | | | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| RPC error codes | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ + +Supported features in the UCX transport: + ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| All RPC methods | × (4) | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Authentication handlers | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call timeouts | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call cancellation | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Concurrent client calls | ✓ (5) | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Custom middleware | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| RPC error codes | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ Notes: * \(1) No support for handshake or DoExchange. * \(2) Support using AspNetCore authentication handlers. +* \(3) Whether a single client can support multiple concurrent calls. +* \(4) Only support for DoExchange, DoGet, DoPut, and GetFlightInfo. +* \(5) Each concurrent call is a separate connection to the server + (unlike gRPC where concurrent calls are multiplexed over a single + connection). This will generally provide better throughput but + consumes more resources both on the server and the client. .. seealso:: The :ref:`flight-rpc` specification. +.. _gRPC: https://grpc.io/ +.. _UCX: https://openucx.org/ C Data Interface ================