From b40bb2a0fc8436e66c39d6956837c5edf2da77e4 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 14:56:41 -0500 Subject: [PATCH 01/16] ARROW-15706: [C++][FlightRPC] Implement Flight UCX transport --- cpp/cmake_modules/DefineOptions.cmake | 4 + cpp/src/arrow/CMakeLists.txt | 4 + cpp/src/arrow/flight/CMakeLists.txt | 10 + cpp/src/arrow/flight/flight_benchmark.cc | 35 +- cpp/src/arrow/flight/perf_server.cc | 34 +- .../arrow/flight/transport/ucx/CMakeLists.txt | 77 ++ .../ucx/flight_transport_ucx_test.cc | 399 ++++++ cpp/src/arrow/flight/transport/ucx/ucx.cc | 43 + cpp/src/arrow/flight/transport/ucx/ucx.h | 35 + .../arrow/flight/transport/ucx/ucx_client.cc | 714 ++++++++++ .../flight/transport/ucx/ucx_internal.cc | 1148 +++++++++++++++++ .../arrow/flight/transport/ucx/ucx_internal.h | 352 +++++ .../arrow/flight/transport/ucx/ucx_server.cc | 660 ++++++++++ .../flight/transport/ucx/util_internal.cc | 212 +++ .../flight/transport/ucx/util_internal.h | 58 + cpp/src/arrow/flight/transport_server.cc | 5 +- cpp/src/arrow/util/config.h.cmake | 1 + docs/source/cpp/flight.rst | 35 + docs/source/status.rst | 78 +- 19 files changed, 3883 insertions(+), 21 deletions(-) create mode 100644 cpp/src/arrow/flight/transport/ucx/CMakeLists.txt create mode 100644 cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx.cc create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx.h create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_client.cc create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_internal.cc create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_internal.h create mode 100644 cpp/src/arrow/flight/transport/ucx/ucx_server.cc create mode 100644 cpp/src/arrow/flight/transport/ucx/util_internal.cc create mode 100644 cpp/src/arrow/flight/transport/ucx/util_internal.h diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 05fc14bbc72..ec1e0b6352a 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -391,6 +391,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF) define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF) + define_option(ARROW_WITH_UCX + "Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)" + OFF) + define_option(ARROW_WITH_UTF8PROC "Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)" ON) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index e9e826097b3..b6f1e2481fa 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -747,6 +747,10 @@ endif() if(ARROW_FLIGHT) add_subdirectory(flight) + + if(ARROW_WITH_UCX) + add_subdirectory(flight/transport/ucx) + endif() endif() if(ARROW_FLIGHT_SQL) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 7447e675e08..f9d135654b4 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -313,4 +313,14 @@ if(ARROW_BUILD_BENCHMARKS) add_dependencies(arrow-flight-benchmark arrow-flight-perf-server) add_dependencies(arrow_flight arrow-flight-benchmark) + + if(ARROW_WITH_UCX) + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static) + target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static) + else() + target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared) + target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared) + endif() + endif() endif(ARROW_BUILD_BENCHMARKS) diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index 872c67c80b7..b5594e2d056 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -40,12 +40,20 @@ #include "arrow/flight/test_util.h" #ifdef ARROW_CUDA +#include #include "arrow/gpu/cuda_api.h" #endif +#ifdef ARROW_WITH_UCX +#include "arrow/flight/transport/ucx/ucx.h" +#endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); DEFINE_string(transport, "grpc", - "The network transport to use. Supported: \"grpc\" (default)."); + "The network transport to use. Supported: \"grpc\" (default)" +#ifdef ARROW_WITH_UCX + ", \"ucx\"" +#endif // ARROW_WITH_UCX + "."); DEFINE_string(server_host, "", "An existing performance server to benchmark against (leave blank to spawn " "one automatically)"); @@ -497,6 +505,21 @@ int main(int argc, char** argv) { options.disable_server_verification = true; } } + } else if (FLAGS_transport == "ucx") { +#ifdef ARROW_WITH_UCX + arrow::flight::transport::ucx::InitializeFlightUcx(); + if (FLAGS_test_unix || !FLAGS_server_unix.empty()) { + std::cerr << "Transport does not support domain sockets: " << FLAGS_transport + << std::endl; + return EXIT_FAILURE; + } + ARROW_CHECK_OK(arrow::flight::Location::Parse( + "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_server_port), + &location)); +#else + std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; +#endif } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; @@ -514,6 +537,16 @@ int main(int argc, char** argv) { ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager)); ABORT_NOT_OK(manager->GetDevice(0).Value(&device)); call_options.memory_manager = device->default_memory_manager(); + + // Needed to prevent UCX warning + // cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid + // device context + std::shared_ptr context; + ABORT_NOT_OK(device->GetContext().Value(&context)); + auto cuda_status = cuCtxPushCurrent(reinterpret_cast(context->handle())); + if (cuda_status != CUDA_SUCCESS) { + ARROW_LOG(WARNING) << "CUDA error " << cuda_status; + } #else std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl; return 1; diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index cc42ffedd68..6bb30f62ddd 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -43,10 +44,17 @@ #ifdef ARROW_CUDA #include "arrow/gpu/cuda_api.h" #endif +#ifdef ARROW_WITH_UCX +#include "arrow/flight/transport/ucx/ucx.h" +#endif DEFINE_bool(cuda, false, "Allocate results in CUDA memory"); DEFINE_string(transport, "grpc", - "The network transport to use. Supported: \"grpc\" (default)."); + "The network transport to use. Supported: \"grpc\" (default)" +#ifdef ARROW_WITH_UCX + ", \"ucx\"" +#endif // ARROW_WITH_UCX + "."); DEFINE_string(server_host, "localhost", "Host where the server is running on"); DEFINE_int32(port, 31337, "Server port to listen on"); DEFINE_string(server_unix, "", "Unix socket path where the server is running on"); @@ -274,6 +282,29 @@ int main(int argc, char** argv) { ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix) .Value(&connect_location)); } + } else if (FLAGS_transport == "ucx") { +#ifdef ARROW_WITH_UCX + arrow::flight::transport::ucx::InitializeFlightUcx(); + if (FLAGS_server_unix.empty()) { + if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { + std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; + } + ARROW_CHECK_OK(arrow::flight::Location::Parse( + "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_port), + &bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::Parse( + "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_port), + &connect_location)); + } else { + std::cerr << "Transport does not support domain sockets: " << FLAGS_transport + << std::endl; + return EXIT_FAILURE; + } +#else + std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; + return EXIT_FAILURE; +#endif } else { std::cerr << "Unknown transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; @@ -308,6 +339,7 @@ int main(int argc, char** argv) { // Exit with a clean error code (0) on SIGTERM ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); std::cout << "Server transport: " << FLAGS_transport << std::endl; + std::cout << "Server location: " << connect_location.ToString() << std::endl; if (FLAGS_server_unix.empty()) { std::cout << "Server host: " << FLAGS_server_host << std::endl; std::cout << "Server port: " << FLAGS_port << std::endl; diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt new file mode 100644 index 00000000000..2c0e71de59f --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(arrow_flight_transport_ucx) +arrow_install_all_headers("arrow/flight/transport/ucx") + +find_package(PkgConfig REQUIRED) +pkg_check_modules(UCX REQUIRED ucx) + +set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS + ucx_client.cc + ucx_server.cc + ucx.cc + ucx_internal.cc + util_internal.cc) +set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS) + +include_directories(SYSTEM ${UCX_INCLUDE_DIRS}) +list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS ${UCX_LIBRARIES}) + +add_arrow_lib(arrow_flight_transport_ucx + # CMAKE_PACKAGE_NAME + # ArrowFlightTransportUcx + # PKG_CONFIG_NAME + # arrow-flight-transport-ucx + SOURCES + ${ARROW_FLIGHT_TRANSPORT_UCX_SRCS} + PRECOMPILED_HEADERS + "$<$:arrow/flight/transport/ucx/pch.h>" + DEPENDENCIES + SHARED_LINK_FLAGS + ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt + SHARED_LINK_LIBS + arrow_shared + arrow_flight_shared + ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS} + STATIC_LINK_LIBS + arrow_static + arrow_flight_static + ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS}) + +if(ARROW_BUILD_TESTS) + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS + arrow_static + arrow_flight_static + arrow_flight_testing_static + arrow_flight_transport_ucx_static + ${ARROW_TEST_LINK_LIBS}) + else() + set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS + arrow_shared + arrow_flight_shared + arrow_flight_testing_shared + arrow_flight_transport_ucx_shared + ${ARROW_TEST_LINK_LIBS}) + endif() + add_arrow_test(flight_transport_ucx_test + STATIC_LINK_LIBS + ${ARROW_FLIGHT_UCX_TEST_LINK_LIBS} + LABELS + "arrow_flight") +endif() diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc new file mode 100644 index 00000000000..bb9e15fe1bf --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc @@ -0,0 +1,399 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/flight/test_definitions.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/transport/ucx/ucx.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/config.h" + +#ifdef UCP_API_VERSION +#error "UCX headers should not be in public API" +#endif + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#ifdef ARROW_CUDA +#include "arrow/gpu/cuda_api.h" +#endif + +namespace arrow { +namespace flight { + +class UcxEnvironment : public ::testing::Environment { + public: + void SetUp() override { transport::ucx::InitializeFlightUcx(); } +}; + +testing::Environment* const kUcxEnvironment = + testing::AddGlobalTestEnvironment(new UcxEnvironment()); + +//------------------------------------------------------------ +// Common transport tests + +class UcxConnectivityTest : public ConnectivityTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_CONNECTIVITY(UcxConnectivityTest); + +class UcxDataTest : public DataTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_DATA(UcxDataTest); + +class UcxDoPutTest : public DoPutTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_DO_PUT(UcxDoPutTest); + +class UcxAppMetadataTest : public AppMetadataTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_APP_METADATA(UcxAppMetadataTest); + +class UcxIpcOptionsTest : public IpcOptionsTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_IPC_OPTIONS(UcxIpcOptionsTest); + +class UcxCudaDataTest : public CudaDataTest { + protected: + std::string transport() const override { return "ucx"; } +}; +ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest); + +//------------------------------------------------------------ +// UCX internals tests + +constexpr std::initializer_list kStatusCodes = { + StatusCode::OK, + StatusCode::OutOfMemory, + StatusCode::KeyError, + StatusCode::TypeError, + StatusCode::Invalid, + StatusCode::IOError, + StatusCode::CapacityError, + StatusCode::IndexError, + StatusCode::Cancelled, + StatusCode::UnknownError, + StatusCode::NotImplemented, + StatusCode::SerializationError, + StatusCode::RError, + StatusCode::CodeGenError, + StatusCode::ExpressionValidationError, + StatusCode::ExecutionError, + StatusCode::AlreadyExists, +}; + +constexpr std::initializer_list kFlightStatusCodes = { + FlightStatusCode::Internal, FlightStatusCode::TimedOut, + FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated, + FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable, + FlightStatusCode::Failed, +}; + +class TestStatusDetail : public StatusDetail { + public: + const char* type_id() const override { return "test-status-detail"; } + std::string ToString() const override { return "Custom status detail"; } +}; + +namespace transport { +namespace ucx { + +static constexpr std::initializer_list kFrameTypes = { + FrameType::kHeaders, FrameType::kBuffer, FrameType::kPayloadHeader, + FrameType::kPayloadBody, FrameType::kDisconnect, +}; + +TEST(FrameHeader, Basics) { + for (const auto frame_type : kFrameTypes) { + FrameHeader header; + ASSERT_OK(header.Set(frame_type, /*counter=*/42, /*body_size=*/65535)); + if (frame_type == FrameType::kDisconnect) { + ASSERT_RAISES(Cancelled, Frame::ParseHeader(header.data(), header.size())); + } else { + ASSERT_OK_AND_ASSIGN(auto frame, Frame::ParseHeader(header.data(), header.size())); + ASSERT_EQ(frame->type, frame_type); + ASSERT_EQ(frame->counter, 42); + ASSERT_EQ(frame->size, 65535); + } + } +} + +TEST(FrameHeader, FrameType) { + for (const auto frame_type : kFrameTypes) { + ASSERT_LE(static_cast(frame_type), static_cast(FrameType::kMaxFrameType)); + } +} + +TEST(HeadersFrame, Parse) { + const char* data = + ("\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x03x-foobar" + "\x00\x00\x00\x05\x00\x00\x00\x01x-bin\x01"); + constexpr int64_t size = 34; + + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), size)); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Parse(std::move(buffer))); + ASSERT_OK_AND_ASSIGN(auto foo, headers.Get("x-foo")); + ASSERT_EQ(foo, "bar"); + ASSERT_OK_AND_ASSIGN(auto bin, headers.Get("x-bin")); + ASSERT_EQ(bin, "\x01"); + } + { + std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 3)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("expected number of headers"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer(new Buffer(reinterpret_cast(data), 7)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("expected length of key 1"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), 10)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("expected length of value 1"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), 12)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("expected key 1 to have length 5, but only 0 bytes remain"), + HeadersFrame::Parse(std::move(buffer))); + } + { + std::unique_ptr buffer( + new Buffer(reinterpret_cast(data), 17)); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "expected value 1 to have length 3, but only 0 bytes remain"), + HeadersFrame::Parse(std::move(buffer))); + } +} + +TEST(HeadersFrame, RoundTripStatus) { + for (const auto code : kStatusCodes) { + { + Status expected = code == StatusCode::OK ? Status() : Status(code, "foo"); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); + Status status; + ASSERT_OK(headers.GetStatus(&status)); + ASSERT_EQ(status, expected); + } + + if (code == StatusCode::OK) continue; + + // Attach a generic status detail + { + auto detail = std::make_shared(); + Status original(code, "foo", detail); + Status expected(code, "foo", + std::make_shared(FlightStatusCode::Internal, + detail->ToString())); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); + Status status; + ASSERT_OK(headers.GetStatus(&status)); + ASSERT_EQ(status, expected); + } + + // Attach a Flight status detail + for (const auto flight_code : kFlightStatusCodes) { + Status expected(code, "foo", + std::make_shared(flight_code, "extra")); + ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {})); + Status status; + ASSERT_OK(headers.GetStatus(&status)); + ASSERT_EQ(status, expected); + } + } +} +} // namespace ucx +} // namespace transport + +//------------------------------------------------------------ +// Ad-hoc UCX-specific tests + +class SimpleTestServer : public FlightServerBase { + public: + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* info) override { + if (request.path.size() > 0 && request.path[0] == "error") { + return status_; + } + auto examples = ExampleFlightInfo(); + *info = std::unique_ptr(new FlightInfo(examples[0])); + return Status::OK(); + } + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + auto batch_reader = std::make_shared(batches[0]->schema(), batches); + *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); + return Status::OK(); + } + + void set_error_status(Status st) { status_ = std::move(st); } + + private: + Status status_; +}; + +class TestUcx : public ::testing::Test { + public: + void SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "localhost", 0)); + ASSERT_OK(MakeServer( + location, &server_, &client_, + [](FlightServerOptions* options) { return Status::OK(); }, + [](FlightClientOptions* options) { return Status::OK(); })); + } + + void TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); + } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +TEST_F(TestUcx, GetFlightInfo) { + auto descriptor = FlightDescriptor::Path({"foo", "bar"}); + std::unique_ptr info; + ASSERT_OK(client_->GetFlightInfo(descriptor, &info)); + // Test that we can reuse the connection + ASSERT_OK(client_->GetFlightInfo(descriptor, &info)); +} + +TEST_F(TestUcx, SequentialClients) { + std::unique_ptr client2; + ASSERT_OK(FlightClient::Connect(server_->location(), FlightClientOptions::Defaults(), + &client2)); + + Ticket ticket{"a"}; + + std::unique_ptr stream1, stream2; + std::shared_ptr table1, table2; + + ASSERT_OK(client_->DoGet(ticket, &stream1)); + ASSERT_OK(stream1->ReadAll(&table1)); + + ASSERT_OK(client_->DoGet(ticket, &stream2)); + ASSERT_OK(stream2->ReadAll(&table2)); + + AssertTablesEqual(*table1, *table2); +} + +TEST_F(TestUcx, ConcurrentClients) { + std::unique_ptr client2; + ASSERT_OK(FlightClient::Connect(server_->location(), FlightClientOptions::Defaults(), + &client2)); + + Ticket ticket{"a"}; + + std::unique_ptr stream1, stream2; + std::shared_ptr
table1, table2; + + ASSERT_OK(client_->DoGet(ticket, &stream1)); + ASSERT_OK(client2->DoGet(ticket, &stream2)); + + ASSERT_OK(stream1->ReadAll(&table1)); + ASSERT_OK(stream2->ReadAll(&table2)); + + AssertTablesEqual(*table1, *table2); +} + +TEST_F(TestUcx, Errors) { + auto descriptor = FlightDescriptor::Path({"error", "bar"}); + std::unique_ptr info; + + auto* server = reinterpret_cast(server_.get()); + for (const auto code : kStatusCodes) { + if (code == StatusCode::OK) continue; + + Status expected(code, "Error message"); + server->set_error_status(expected); + Status actual = client_->GetFlightInfo(descriptor, &info); + ASSERT_EQ(actual, expected); + + // Attach a generic status detail + { + auto detail = std::make_shared(); + server->set_error_status(Status(code, "foo", detail)); + Status expected(code, "foo", + std::make_shared(FlightStatusCode::Internal, + detail->ToString())); + Status actual = client_->GetFlightInfo(descriptor, &info); + ASSERT_EQ(actual, expected); + } + + // Attach a Flight status detail + for (const auto flight_code : kFlightStatusCodes) { + Status expected(code, "Error message", + std::make_shared(flight_code, "extra")); + server->set_error_status(expected); + Status actual = client_->GetFlightInfo(descriptor, &info); + ASSERT_EQ(actual, expected); + } + } +} + +TEST(TestUcxIpV6, DISABLED_IpV6Port) { + // TODO(lidavidm): while we can listen on IPv6 fine, we can't + // actually connect to it (ucp_conn_request_h appears to point to a + // port where nobody is listening) + + // Also, disabled in CI as machines lack an IPv6 interface + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0)); + + std::unique_ptr server(new SimpleTestServer()); + FlightServerOptions server_options(location); + ASSERT_OK(server->Init(server_options)); + + std::unique_ptr client; + FlightClientOptions client_options = FlightClientOptions::Defaults(); + ASSERT_OK(FlightClient::Connect(server->location(), client_options, &client)); + + auto descriptor = FlightDescriptor::Path({"foo", "bar"}); + std::unique_ptr info; + ASSERT_OK(client->GetFlightInfo(descriptor, &info)); +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc new file mode 100644 index 00000000000..0b61dbb93e9 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx.cc @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/ucx.h" + +#include + +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/ucx_internal.h" +#include "arrow/flight/transport_server.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +std::once_flag kInitializeOnce; +void InitializeFlightUcx() { + std::call_once(kInitializeOnce, []() { + auto* registry = flight::internal::GetDefaultTransportRegistry(); + DCHECK_OK(registry->RegisterClient("ucx", MakeUcxClientImpl)); + DCHECK_OK(registry->RegisterServer("ucx", MakeUcxServerImpl)); + }); +} +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.h b/cpp/src/arrow/flight/transport/ucx/ucx.h new file mode 100644 index 00000000000..dda2c83035c --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx.h @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Experimental UCX-based transport for Flight. + +#pragma once + +#include "arrow/flight/visibility.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +ARROW_FLIGHT_EXPORT +void InitializeFlightUcx(); + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc new file mode 100644 index 00000000000..04122942ba5 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -0,0 +1,714 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// The client-side implementation of a UCX-based transport for +/// Flight. +/// +/// Each UCX driver is used to support one call at a time. This gives +/// the greatest throughput for data plane methods, but is relatively +/// expensive in terms of other resources, both for the server and the +/// client. (UCX drivers have multiple threading modes: single-thread +/// access, serialized access, and multi-thread access. Testing found +/// that multi-thread access incurred high synchronization costs.) +/// Hence, for concurrent calls in a single client, we must maintain +/// multiple drivers, and so unlike gRPC, there is no real difference +/// between using one client concurrently and using multiple +/// independent clients. + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#include +#include +#include +#include + +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/client.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +class UcxClientImpl; + +namespace { + +Status MergeStatuses(Status server_status, Status transport_status) { + if (server_status.ok()) { + if (transport_status.ok()) return server_status; + return transport_status; + } else if (transport_status.ok()) { + return server_status; + } + return Status::FromDetailAndArgs(server_status.code(), server_status.detail(), + server_status.message(), + ". Transport context: ", transport_status.ToString()); +} + +/// \brief An individual connection to the server. +class ClientConnection { + public: + ClientConnection() = default; + ARROW_DISALLOW_COPY_AND_ASSIGN(ClientConnection); + ARROW_DEFAULT_MOVE_AND_ASSIGN(ClientConnection); + ~ClientConnection() { DCHECK(!driver_) << "Connection was not closed!"; } + + Status Init(std::shared_ptr ucp_context, const arrow::internal::Uri& uri) { + auto status = InitImpl(std::move(ucp_context), uri); + // Clean up after-the-fact if we fail to initialize + if (!status.ok()) { + if (driver_) { + status = MergeStatuses(std::move(status), driver_->Close()); + driver_.reset(); + remote_endpoint_ = nullptr; + } + if (ucp_worker_) ucp_worker_.reset(); + } + return status; + } + + Status InitImpl(std::shared_ptr ucp_context, + const arrow::internal::Uri& uri) { + { + ucs_status_t status; + ucp_worker_params_t worker_params; + std::memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED; + + ucp_worker_h ucp_worker; + status = ucp_worker_create(ucp_context->get(), &worker_params, &ucp_worker); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); + ucp_worker_.reset(new UcpWorker(std::move(ucp_context), ucp_worker)); + } + { + // Create endpoint for remote worker + struct sockaddr_storage connect_addr; + ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr)); + + ucp_ep_params_t params; + params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME | + UCP_EP_PARAM_FIELD_SOCK_ADDR; + params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER; + params.name = "UcxClientImpl"; + params.sockaddr.addr = reinterpret_cast(&connect_addr); + params.sockaddr.addrlen = addrlen; + + auto status = ucp_ep_create(ucp_worker_->get(), ¶ms, &remote_endpoint_); + RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status)); + } + + driver_.reset(new UcpCallDriver(ucp_worker_, remote_endpoint_)); + + { + // Set up Active Message (AM) handler + ucp_am_handler_param_t handler_params; + handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | + UCP_AM_HANDLER_PARAM_FIELD_CB | + UCP_AM_HANDLER_PARAM_FIELD_ARG; + handler_params.id = kUcpAmHandlerId; + handler_params.cb = HandleIncomingActiveMessage; + handler_params.arg = driver_.get(); + ucs_status_t status = + ucp_worker_set_am_recv_handler(ucp_worker_->get(), &handler_params); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status)); + } + + return Status::OK(); + } + + Status Close() { + if (!driver_) return Status::OK(); + + auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0); + status = MergeStatuses(std::move(status), driver_->Close()); + + driver_.reset(); + remote_endpoint_ = nullptr; + ucp_worker_.reset(); + return status; + } + + UcpCallDriver* driver() { + DCHECK(driver_); + return driver_.get(); + } + + private: + static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header, + size_t header_length, void* data, + size_t data_length, + const ucp_am_recv_param_t* param) { + auto* driver = reinterpret_cast(self); + return driver->RecvActiveMessage(header, header_length, data, data_length, param); + } + + std::shared_ptr ucp_worker_; + ucp_ep_h remote_endpoint_; + std::unique_ptr driver_; +}; + +class UcxClientStream : public internal::ClientDataStream { + public: + UcxClientStream(UcxClientImpl* impl, ClientConnection conn) + : impl_(impl), + conn_(std::move(conn)), + driver_(conn_.driver()), + writes_done_(false), + finished_(false) {} + + protected: + Status DoFinish() override; + + UcxClientImpl* impl_; + ClientConnection conn_; + UcpCallDriver* driver_; + bool writes_done_; + bool finished_; + Status io_status_; + Status server_status_; +}; + +class GetClientStream : public UcxClientStream { + public: + GetClientStream(UcxClientImpl* impl, ClientConnection conn) + : UcxClientStream(impl, std::move(conn)) { + writes_done_ = true; + } + + bool ReadData(internal::FlightData* data) override { + if (finished_) return false; + + bool success = true; + io_status_ = ReadImpl(data).Value(&success); + + if (!io_status_.ok() || !success) { + finished_ = true; + } + return success; + } + + private: + ::arrow::Result ReadImpl(internal::FlightData* data) { + ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame()); + + if (frame->type == FrameType::kHeaders) { + // Trailers, stream is over + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); + RETURN_NOT_OK(headers.GetStatus(&server_status_)); + return false; + } + + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader)); + PayloadHeaderFrame payload_header(std::move(frame->buffer)); + RETURN_NOT_OK(payload_header.ToFlightData(data)); + + // DoGet does not support metadata-only messages, so we can always + // assume we have an IPC payload + ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr)); + + if (ipc::Message::HasBody(message->type())) { + ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame()); + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody)); + data->body = std::move(frame->buffer); + } + return true; + } +}; + +class WriteClientStream : public UcxClientStream { + public: + WriteClientStream(UcxClientImpl* impl, ClientConnection conn) + : UcxClientStream(impl, std::move(conn)) { + std::thread t(&WriteClientStream::DriveWorker, this); + driver_thread_.swap(t); + } + arrow::Result WriteData(const FlightPayload& payload) override { + std::unique_lock guard(driver_mutex_); + if (finished_ || writes_done_) return Status::Invalid("Already done writing"); + outgoing_ = driver_->SendFlightPayload(payload); + working_cv_.notify_all(); + received_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); + + auto status = outgoing_.status(); + outgoing_ = Future<>(); + RETURN_NOT_OK(status); + return true; + } + Status WritesDone() override { + std::unique_lock guard(driver_mutex_); + if (!writes_done_) { + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make({})); + outgoing_ = + driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer()); + working_cv_.notify_all(); + received_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); + + writes_done_ = true; + auto status = outgoing_.status(); + outgoing_ = Future<>(); + RETURN_NOT_OK(status); + } + return Status::OK(); + } + + protected: + void JoinThread() { + try { + driver_thread_.join(); + } catch (const std::system_error&) { + // Ignore + } + } + void DriveWorker() { + while (true) { + { + std::unique_lock guard(driver_mutex_); + working_cv_.wait(guard, [this] { + return finished_ || incoming_.is_valid() || outgoing_.is_valid(); + }); + if (finished_) return; + } + + while (true) { + std::unique_lock guard(driver_mutex_); + if (!incoming_.is_valid() && !outgoing_.is_valid()) break; + if (incoming_.is_valid() && incoming_.is_finished()) { + if (!incoming_.status().ok()) { + io_status_ = incoming_.status(); + finished_ = true; + } else { + HandleIncomingMessage(*incoming_.result()); + } + incoming_ = Future>(); + received_cv_.notify_all(); + break; + } + if (outgoing_.is_valid() && outgoing_.is_finished()) { + received_cv_.notify_all(); + break; + } + driver_->MakeProgress(); + if (finished_) return; + } + } + } + + virtual void HandleIncomingMessage(const std::shared_ptr& frame) {} + + std::mutex driver_mutex_; + std::thread driver_thread_; + std::condition_variable received_cv_; + std::condition_variable working_cv_; + Future> incoming_; + Future<> outgoing_; +}; + +class PutClientStream : public WriteClientStream { + public: + using WriteClientStream::WriteClientStream; + bool ReadPutMetadata(std::shared_ptr* out) override { + std::unique_lock guard(driver_mutex_); + if (finished_) { + *out = nullptr; + guard.unlock(); + JoinThread(); + return false; + } + next_metadata_ = nullptr; + incoming_ = driver_->ReadFrameAsync(); + working_cv_.notify_all(); + received_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; }); + + if (finished_) { + *out = nullptr; + guard.unlock(); + JoinThread(); + return false; + } + *out = std::move(next_metadata_); + return true; + } + + private: + void HandleIncomingMessage(const std::shared_ptr& frame) override { + // No lock here, since this is called from DriveWorker() which is + // holding the lock + if (frame->type == FrameType::kBuffer) { + next_metadata_ = std::move(frame->buffer); + } else if (frame->type == FrameType::kHeaders) { + // Trailers, stream is over + finished_ = true; + HeadersFrame headers; + io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers); + if (!io_status_.ok()) { + finished_ = true; + return; + } + io_status_ = headers.GetStatus(&server_status_); + if (!io_status_.ok()) { + finished_ = true; + return; + } + } else { + finished_ = true; + io_status_ = + Status::IOError("Unexpected frame type ", static_cast(frame->type)); + } + } + std::shared_ptr next_metadata_; +}; + +class ExchangeClientStream : public WriteClientStream { + public: + ExchangeClientStream(UcxClientImpl* impl, ClientConnection conn) + : WriteClientStream(impl, std::move(conn)), read_state_(ReadState::kFinished) {} + + bool ReadData(internal::FlightData* data) override { + std::unique_lock guard(driver_mutex_); + if (finished_) { + guard.unlock(); + JoinThread(); + return false; + } + + // Drive the read loop here. (We can't recursively call + // ReadFrameAsync below since the internal mutex is not + // recursive.) + read_state_ = ReadState::kExpectHeader; + incoming_ = driver_->ReadFrameAsync(); + working_cv_.notify_all(); + received_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; }); + if (read_state_ != ReadState::kFinished) { + incoming_ = driver_->ReadFrameAsync(); + working_cv_.notify_all(); + received_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; }); + } + + if (finished_) { + guard.unlock(); + JoinThread(); + return false; + } + *data = std::move(next_data_); + return true; + } + + private: + enum class ReadState { + kFinished, + kExpectHeader, + kExpectBody, + }; + + std::string DebugExpectingString() { + switch (read_state_) { + case ReadState::kFinished: + return "(not expecting a frame)"; + case ReadState::kExpectHeader: + return "payload header frame"; + case ReadState::kExpectBody: + return "payload body frame"; + } + return "(unknown or invalid state)"; + } + + void HandleIncomingMessage(const std::shared_ptr& frame) override { + // No lock here, since this is called from MakeProgress() + // which is called under the lock already + if (frame->type == FrameType::kPayloadHeader) { + if (read_state_ != ReadState::kExpectHeader) { + finished_ = true; + io_status_ = Status::IOError("Got unexpected payload header frame, expected: ", + DebugExpectingString()); + return; + } + + PayloadHeaderFrame payload_header(std::move(frame->buffer)); + io_status_ = payload_header.ToFlightData(&next_data_); + if (!io_status_.ok()) { + finished_ = true; + return; + } + + if (next_data_.metadata) { + std::unique_ptr message; + io_status_ = ipc::Message::Open(next_data_.metadata, nullptr).Value(&message); + if (!io_status_.ok()) { + finished_ = true; + return; + } + if (ipc::Message::HasBody(message->type())) { + read_state_ = ReadState::kExpectBody; + return; + } + } + read_state_ = ReadState::kFinished; + } else if (frame->type == FrameType::kPayloadBody) { + next_data_.body = std::move(frame->buffer); + read_state_ = ReadState::kFinished; + } else if (frame->type == FrameType::kHeaders) { + // Trailers, stream is over + finished_ = true; + read_state_ = ReadState::kFinished; + HeadersFrame headers; + io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers); + if (!io_status_.ok()) { + finished_ = true; + return; + } + io_status_ = headers.GetStatus(&server_status_); + if (!io_status_.ok()) { + finished_ = true; + return; + } + } else { + finished_ = true; + io_status_ = + Status::IOError("Unexpected frame type ", static_cast(frame->type)); + read_state_ = ReadState::kFinished; + } + } + + internal::FlightData next_data_; + ReadState read_state_; +}; +} // namespace + +class ARROW_FLIGHT_EXPORT UcxClientImpl + : public arrow::flight::internal::ClientTransport { + public: + UcxClientImpl() {} + + virtual ~UcxClientImpl() { + if (!ucp_context_) return; + auto status = Close(); + if (!status.ok()) { + ARROW_LOG(WARNING) << "UcxClientImpl errored in Close() in destructor: " + << status.ToString(); + } + } + + Status Init(const FlightClientOptions& options, const Location& location, + const arrow::internal::Uri& uri) override { + RETURN_NOT_OK(uri_.Parse(uri.ToString())); + { + ucp_config_t* ucp_config; + ucp_params_t ucp_params; + ucs_status_t status; + + status = ucp_config_read(nullptr, nullptr, &ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); + + std::memset(&ucp_params, 0, sizeof(ucp_params)); + ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; + ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; + + ucp_context_h ucp_context; + status = ucp_init(&ucp_params, ucp_config, &ucp_context); + ucp_config_release(ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_init", status)); + ucp_context_.reset(new UcpContext(ucp_context)); + } + + RETURN_NOT_OK(MakeConnection()); + return Status::OK(); + } + + Status Close() override { + std::unique_lock connections_mutex_; + while (!connections_.empty()) { + RETURN_NOT_OK(connections_.front().Close()); + connections_.pop_front(); + } + return Status::OK(); + } + + Status GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* info) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto impl = [&]() { + RETURN_NOT_OK(driver->StartCall(kMethodGetFlightInfo)); + + ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString()); + + RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, + reinterpret_cast(payload.data()), + static_cast(payload.size()))); + + ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame()); + if (incoming_message->type == FrameType::kBuffer) { + ARROW_ASSIGN_OR_RAISE( + *info, FlightInfo::Deserialize(util::string_view(*incoming_message->buffer))); + ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame()); + } + RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders)); + ARROW_ASSIGN_OR_RAISE(auto headers, + HeadersFrame::Parse(std::move(incoming_message->buffer))); + Status status; + RETURN_NOT_OK(headers.GetStatus(&status)); + return status; + }; + auto status = impl(); + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoExchange(const FlightCallOptions& options, + std::unique_ptr* out) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto status = driver->StartCall(kMethodDoExchange); + if (ARROW_PREDICT_TRUE(status.ok())) { + *out = + arrow::internal::make_unique(this, std::move(connection)); + return Status::OK(); + } + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* stream) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto impl = [&]() { + RETURN_NOT_OK(driver->StartCall(kMethodDoGet)); + ARROW_ASSIGN_OR_RAISE(std::string payload, ticket.SerializeToString()); + RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, + reinterpret_cast(payload.data()), + static_cast(payload.size()))); + *stream = + arrow::internal::make_unique(this, std::move(connection)); + return Status::OK(); + }; + + auto status = impl(); + if (ARROW_PREDICT_TRUE(status.ok())) return status; + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoPut(const FlightCallOptions& options, + std::unique_ptr* out) override { + ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options)); + UcpCallDriver* driver = connection.driver(); + + auto status = driver->StartCall(kMethodDoPut); + if (ARROW_PREDICT_TRUE(status.ok())) { + *out = arrow::internal::make_unique(this, std::move(connection)); + return Status::OK(); + } + return MergeStatuses(std::move(status), ReturnConnection(std::move(connection))); + } + + Status DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) override { + // XXX: fake this for now to get the perf test to work + return Status::OK(); + } + + Status MakeConnection() { + ClientConnection conn; + RETURN_NOT_OK(conn.Init(ucp_context_, uri_)); + connections_.push_back(std::move(conn)); + return Status::OK(); + } + + arrow::Result CheckoutConnection(const FlightCallOptions& options) { + std::unique_lock connections_mutex_; + if (connections_.empty()) RETURN_NOT_OK(MakeConnection()); + ClientConnection conn = std::move(connections_.front()); + conn.driver()->set_memory_manager(options.memory_manager); + conn.driver()->set_read_memory_pool(options.read_options.memory_pool); + conn.driver()->set_write_memory_pool(options.write_options.memory_pool); + connections_.pop_front(); + return conn; + } + + Status ReturnConnection(ClientConnection conn) { + std::unique_lock connections_mutex_; + // TODO(lidavidm): for future improvement: reclaim clients + // asynchronously in the background (try to avoid issues like + // constantly opening/closing clients because the application is + // just barely over the limit of open connections) + if (connections_.size() >= kMaxOpenConnections) { + RETURN_NOT_OK(conn.Close()); + return Status::OK(); + } + connections_.push_back(std::move(conn)); + return Status::OK(); + } + + private: + static constexpr size_t kMaxOpenConnections = 3; + + arrow::internal::Uri uri_; + std::shared_ptr ucp_context_; + std::mutex connections_mutex_; + std::deque connections_; +}; + +Status UcxClientStream::DoFinish() { + RETURN_NOT_OK(WritesDone()); + if (!finished_) { + internal::FlightData message; + std::shared_ptr metadata; + while (ReadData(&message)) { + } + while (ReadPutMetadata(&metadata)) { + } + finished_ = true; + } + if (impl_) { + auto status = impl_->ReturnConnection(std::move(conn_)); + impl_ = nullptr; + driver_ = nullptr; + if (!status.ok()) { + if (io_status_.ok()) { + io_status_ = std::move(status); + } else { + io_status_ = Status::FromDetailAndArgs( + io_status_.code(), io_status_.detail(), io_status_.message(), + ". Transport context: ", status.ToString()); + } + } + } + return MergeStatuses(server_status_, io_status_); +} + +std::unique_ptr MakeUcxClientImpl() { + return arrow::internal::make_unique(); +} + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc new file mode 100644 index 00000000000..5baabc5b805 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -0,0 +1,1148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/flight/types.h" +#include "arrow/util/base64.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +// Defines to test different implementation strategies +// Enable the CONTIG path for CPU-only data +// #define ARROW_FLIGHT_UCX_SEND_CONTIG +// Enable ucp_mem_map in IOV path +// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP + +constexpr char kHeaderMethod[] = ":method:"; + +namespace { +Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) { + if (ARROW_PREDICT_FALSE(in < 0)) { + return Status::Invalid("Length cannot be negative"); + } else if (ARROW_PREDICT_FALSE( + in > static_cast(std::numeric_limits::max()))) { + return Status::Invalid("Length cannot exceed uint32_t"); + } + UInt32ToBytesBe(static_cast(in), out); + return Status::OK(); +} +ucs_memory_type InferMemoryType(const Buffer& buffer) { + if (!buffer.is_cpu()) { + return UCS_MEMORY_TYPE_CUDA; + } + return UCS_MEMORY_TYPE_UNKNOWN; +} +void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size, + ucs_memory_type memory_type, ucp_mem_h* memh_p) { + ucp_mem_map_params_t map_param; + std::memset(&map_param, 0, sizeof(map_param)); + map_param.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; + map_param.address = const_cast(buffer); + map_param.length = size; + map_param.memory_type = memory_type; + auto ucs_status = ucp_mem_map(context, &map_param, memh_p); + if (ucs_status != UCS_OK) { + *memh_p = nullptr; + ARROW_LOG(WARNING) << "Could not map memory: " + << FromUcsStatus("ucp_mem_map", ucs_status); + } +} +void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) { + TryMapBuffer(context, reinterpret_cast(buffer.address()), + static_cast(buffer.size()), InferMemoryType(buffer), memh_p); +} +void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) { + if (memh_p) { + auto ucs_status = ucp_mem_unmap(context, memh_p); + if (ucs_status != UCS_OK) { + ARROW_LOG(WARNING) << "Could not unmap memory: " + << FromUcsStatus("ucp_mem_unmap", ucs_status); + } + } +} + +/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA +/// buffer). +/// +/// Owns a reference to the associated worker to avoid undefined +/// behavior. +class UcxDataBuffer : public Buffer { + public: + explicit UcxDataBuffer(std::shared_ptr worker, void* data, size_t size) + : Buffer(const_cast(reinterpret_cast(data)), + static_cast(size)), + worker_(std::move(worker)) {} + + ~UcxDataBuffer() { + ucp_am_data_release(worker_->get(), + const_cast(reinterpret_cast(data()))); + } + + private: + std::shared_ptr worker_; +}; +}; // namespace + +constexpr size_t FrameHeader::kFrameHeaderBytes; +constexpr uint8_t FrameHeader::kFrameVersion; + +Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) { + header[0] = kFrameVersion; + header[1] = static_cast(frame_type); + UInt32ToBytesBe(counter, header.data() + 4); + RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8)); + return Status::OK(); +} + +arrow::Result> Frame::ParseHeader(const void* header, + size_t header_length) { + if (header_length < FrameHeader::kFrameHeaderBytes) { + return Status::IOError("Header is too short, must be at least ", + FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length); + } + + const uint8_t* frame_header = reinterpret_cast(header); + if (frame_header[0] != FrameHeader::kFrameVersion) { + return Status::IOError("Expected frame version ", + static_cast(FrameHeader::kFrameVersion), " but got ", + static_cast(frame_header[0])); + } else if (frame_header[1] > static_cast(FrameType::kMaxFrameType)) { + return Status::IOError("Unknown frame type ", static_cast(frame_header[1])); + } + + const FrameType frame_type = static_cast(frame_header[1]); + const uint32_t frame_counter = BytesToUInt32Be(frame_header + 4); + const uint32_t frame_size = BytesToUInt32Be(frame_header + 8); + + if (frame_type == FrameType::kDisconnect) { + return Status::Cancelled("Client initiated disconnect"); + } + + return std::make_shared(frame_type, frame_size, frame_counter, nullptr); +} + +arrow::Result HeadersFrame::Parse(std::unique_ptr buffer) { + HeadersFrame result; + const uint8_t* payload = buffer->data(); + const uint8_t* end = payload + buffer->size(); + if (ARROW_PREDICT_FALSE((end - payload) < 4)) { + return Status::Invalid("Buffer underflow, expected number of headers"); + } + const uint32_t num_headers = BytesToUInt32Be(payload); + payload += 4; + for (uint32_t i = 0; i < num_headers; i++) { + if (ARROW_PREDICT_FALSE((end - payload) < 4)) { + return Status::Invalid("Buffer underflow, expected length of key ", i + 1); + } + const uint32_t key_length = BytesToUInt32Be(payload); + payload += 4; + + if (ARROW_PREDICT_FALSE((end - payload) < 4)) { + return Status::Invalid("Buffer underflow, expected length of value ", i + 1); + } + const uint32_t value_length = BytesToUInt32Be(payload); + payload += 4; + + if (ARROW_PREDICT_FALSE((end - payload) < key_length)) { + return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ", + key_length, ", but only ", (end - payload), " bytes remain"); + } + const util::string_view key(reinterpret_cast(payload), key_length); + payload += key_length; + + if (ARROW_PREDICT_FALSE((end - payload) < value_length)) { + return Status::Invalid("Buffer underflow, expected value ", i + 1, + " to have length ", value_length, ", but only ", + (end - payload), " bytes remain"); + } + const util::string_view value(reinterpret_cast(payload), value_length); + payload += value_length; + result.headers_.emplace_back(key, value); + } + + result.buffer_ = std::move(buffer); + return result; +} +arrow::Result HeadersFrame::Make( + const std::vector>& headers) { + int32_t total_length = 4 /* # of headers */; + for (const auto& header : headers) { + total_length += 4 /* key length */ + 4 /* value length */ + + header.first.size() /* key */ + header.second.size(); + } + + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length)); + uint8_t* payload = buffer->mutable_data(); + + RETURN_NOT_OK(SizeToUInt32BytesBe(headers.size(), payload)); + payload += 4; + for (const auto& header : headers) { + RETURN_NOT_OK(SizeToUInt32BytesBe(header.first.size(), payload)); + payload += 4; + RETURN_NOT_OK(SizeToUInt32BytesBe(header.second.size(), payload)); + payload += 4; + std::memcpy(payload, header.first.data(), header.first.size()); + payload += header.first.size(); + std::memcpy(payload, header.second.data(), header.second.size()); + payload += header.second.size(); + } + return Parse(std::move(buffer)); +} +arrow::Result HeadersFrame::Make( + const Status& status, + const std::vector>& headers) { + auto all_headers = headers; + all_headers.emplace_back(kHeaderStatusCode, + std::to_string(static_cast(status.code()))); + all_headers.emplace_back(kHeaderStatusMessage, status.message()); + if (status.detail()) { + auto fsd = FlightStatusDetail::UnwrapStatus(status); + if (fsd) { + all_headers.emplace_back(kHeaderStatusDetailCode, + std::to_string(static_cast(fsd->code()))); + all_headers.emplace_back(kHeaderStatusDetail, fsd->extra_info()); + } else { + all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString()); + } + } + return Make(all_headers); +} + +arrow::Result HeadersFrame::Get(const std::string& key) { + for (const auto& pair : headers_) { + if (pair.first == key) return pair.second; + } + return Status::KeyError(key); +} + +Status HeadersFrame::GetStatus(Status* out) { + util::string_view code_str, message_str; + auto status = Get(kHeaderStatusCode).Value(&code_str); + if (!status.ok()) { + return Status::KeyError("Server did not send status code header ", kHeaderStatusCode); + } + + StatusCode status_code = StatusCode::OK; + auto code = std::strtol(code_str.data(), nullptr, /*base=*/10); + switch (code) { + case 0: + status_code = StatusCode::OK; + break; + case 1: + status_code = StatusCode::OutOfMemory; + break; + case 2: + status_code = StatusCode::KeyError; + break; + case 3: + status_code = StatusCode::TypeError; + break; + case 4: + status_code = StatusCode::Invalid; + break; + case 5: + status_code = StatusCode::IOError; + break; + case 6: + status_code = StatusCode::CapacityError; + break; + case 7: + status_code = StatusCode::IndexError; + break; + case 8: + status_code = StatusCode::Cancelled; + break; + case 9: + status_code = StatusCode::UnknownError; + break; + case 10: + status_code = StatusCode::NotImplemented; + break; + case 11: + status_code = StatusCode::SerializationError; + break; + case 13: + status_code = StatusCode::RError; + break; + case 40: + status_code = StatusCode::CodeGenError; + break; + case 41: + status_code = StatusCode::ExpressionValidationError; + break; + case 42: + status_code = StatusCode::ExecutionError; + break; + case 45: + status_code = StatusCode::AlreadyExists; + break; + default: + status_code = StatusCode::UnknownError; + break; + } + if (status_code == StatusCode::OK) { + *out = Status::OK(); + return Status::OK(); + } + + status = Get(kHeaderStatusMessage).Value(&message_str); + if (!status.ok()) { + *out = Status(status_code, "Server did not send status message header", nullptr); + return Status::OK(); + } + + util::string_view detail_code_str, detail_str; + FlightStatusCode detail_code = FlightStatusCode::Internal; + + if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) { + auto detail_code_int = std::strtol(detail_code_str.data(), nullptr, /*base=*/10); + switch (detail_code_int) { + case 1: + detail_code = FlightStatusCode::TimedOut; + break; + case 2: + detail_code = FlightStatusCode::Cancelled; + break; + case 3: + detail_code = FlightStatusCode::Unauthenticated; + break; + case 4: + detail_code = FlightStatusCode::Unauthorized; + break; + case 5: + detail_code = FlightStatusCode::Unavailable; + break; + case 6: + detail_code = FlightStatusCode::Failed; + break; + case 0: + default: + detail_code = FlightStatusCode::Internal; + break; + } + } + ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str)); + + std::shared_ptr detail = nullptr; + if (!detail_str.empty()) { + detail = std::make_shared(detail_code, std::string(detail_str)); + } + *out = Status(status_code, std::string(message_str), std::move(detail)); + return Status::OK(); +} + +namespace { +static constexpr uint32_t kMissingFieldSentinel = std::numeric_limits::max(); +static constexpr uint32_t kInt32Max = + static_cast(std::numeric_limits::max()); +arrow::Result PayloadHeaderFieldSize(const std::string& field, + const std::shared_ptr& data, + uint32_t* total_size) { + if (!data) return kMissingFieldSentinel; + if (data->size() > kInt32Max) { + return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size()); + } + *total_size += static_cast(data->size()); + // Check for underflow + if (*total_size < 0) return Status::Invalid("Payload header must fit in a uint32_t"); + return static_cast(data->size()); +} +uint8_t* PackField(uint32_t size, const std::shared_ptr& data, uint8_t* out) { + UInt32ToBytesBe(size, out); + if (size != kMissingFieldSentinel) { + std::memcpy(out + 4, data->data(), size); + return out + 4 + size; + } else { + return out + 4; + } +} +} // namespace + +arrow::Result PayloadHeaderFrame::Make(const FlightPayload& payload, + MemoryPool* memory_pool) { + // Assemble all non-data fields here. Presumably this is much less + // than data size so we will pay the copy. + + // Structure per field: [4 byte length][data]. If a field is not + // present, UINT32_MAX is used as the sentinel (since 0-sized fields + // are acceptable) + uint32_t header_size = 12; + ARROW_ASSIGN_OR_RAISE( + const uint32_t descriptor_size, + PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size)); + ARROW_ASSIGN_OR_RAISE( + const uint32_t app_metadata_size, + PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size)); + ARROW_ASSIGN_OR_RAISE( + const uint32_t ipc_metadata_size, + PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata, + &header_size)); + + ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool)); + uint8_t* payload_header = header_buffer->mutable_data(); + + payload_header = PackField(descriptor_size, payload.descriptor, payload_header); + payload_header = PackField(app_metadata_size, payload.app_metadata, payload_header); + payload_header = + PackField(ipc_metadata_size, payload.ipc_message.metadata, payload_header); + + return PayloadHeaderFrame(std::move(header_buffer)); +} +Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) { + std::shared_ptr buffer = std::move(buffer_); + + // Unpack the descriptor + uint32_t offset = 0; + uint32_t size = BytesToUInt32Be(buffer->data()); + offset += 4; + if (size != kMissingFieldSentinel) { + if (static_cast(offset + size) > buffer->size()) { + return Status::Invalid("Buffer is too small: expected ", offset + size, + " bytes but have ", buffer->size()); + } + util::string_view desc(reinterpret_cast(buffer->data() + offset), size); + data->descriptor.reset(new FlightDescriptor()); + ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc)); + offset += size; + } else { + data->descriptor = nullptr; + } + + // Unpack app_metadata + size = BytesToUInt32Be(buffer->data() + offset); + offset += 4; + // While we properly handle zero-size vs nullptr metadata here, gRPC + // doesn't (Protobuf doesn't differentiate between the two) + if (size != kMissingFieldSentinel) { + if (static_cast(offset + size) > buffer->size()) { + return Status::Invalid("Buffer is too small: expected ", offset + size, + " bytes but have ", buffer->size()); + } + data->app_metadata = SliceBuffer(buffer, offset, size); + offset += size; + } else { + data->app_metadata = nullptr; + } + + // Unpack the IPC header + size = BytesToUInt32Be(buffer->data() + offset); + offset += 4; + if (size != kMissingFieldSentinel) { + if (static_cast(offset + size) > buffer->size()) { + return Status::Invalid("Buffer is too small: expected ", offset + size, + " bytes but have ", buffer->size()); + } + data->metadata = SliceBuffer(std::move(buffer), offset, size); + } else { + data->metadata = nullptr; + } + data->body = nullptr; + return Status::OK(); +} + +// pImpl the driver since async methods require a stable address +class UcpCallDriver::Impl { + public: +#if defined(ARROW_FLIGHT_UCX_SEND_CONTIG) + constexpr static bool kEnableContigSend = true; +#else + constexpr static bool kEnableContigSend = false; +#endif + + Impl(std::shared_ptr worker, ucp_ep_h endpoint) + : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}), + worker_(std::move(worker)), + endpoint_(endpoint), + read_memory_pool_(default_memory_pool()), + write_memory_pool_(default_memory_pool()), + memory_manager_(CPUDevice::Instance()->default_memory_manager()), + counter_(0) { +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(), + UCS_MEMORY_TYPE_HOST, &padding_memh_p_); +#endif + + ucp_ep_attr_t attrs; + std::memset(&attrs, 0, sizeof(attrs)); + attrs.field_mask = UCP_EP_ATTR_FIELD_NAME; + if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) { + name_ = attrs.name; + } else { + name_ = "(unknown remote)"; + } + } + + ~Impl() { +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + TryUnmapBuffer(worker_->context().get(), padding_memh_p_); +#endif + } + + arrow::Result> ReadNextFrame() { + auto fut = ReadFrameAsync(); + while (!fut.is_finished()) MakeProgress(); + RETURN_NOT_OK(fut.status()); + return fut.MoveResult(); + } + + Future> ReadFrameAsync() { + RETURN_NOT_OK(CheckClosed()); + + std::unique_lock guard(frame_mutex_); + if (ARROW_PREDICT_FALSE(!status_.ok())) return status_; + + const uint32_t counter_value = next_counter_++; + auto it = frames_.find(counter_value); + if (it != frames_.end()) { + Future> fut = it->second; + frames_.erase(it); + return fut; + } + auto pair = frames_.insert({counter_value, Future>::Make()}); + DCHECK(pair.second); + return pair.first->second; + } + + Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) { + static uint8_t kZeroes[1] = {0}; + + RETURN_NOT_OK(CheckClosed()); + + void* request = nullptr; + ucp_request_param_t request_param; + request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS; + request_param.flags = UCP_AM_SEND_FLAG_REPLY; + + // Send frame header + FrameHeader header; + RETURN_NOT_OK(header.Set(frame_type, counter_++, size)); + if (size == 0) { + // UCX appears to crash on zero-byte payloads + request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), + kZeroes, + /*size=*/1, &request_param); + } else { + request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), + data, size, &request_param); + } + RETURN_NOT_OK(CompleteRequestBlocking("ucp_am_send_nbx", request)); + + return Status::OK(); + } + + Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer) { + RETURN_NOT_OK(CheckClosed()); + + ucp_request_param_t request_param; + std::memset(&request_param, 0, sizeof(request_param)); + request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA; + request_param.cb.send = AmSendCallback; + request_param.datatype = ucp_dt_make_contig(1); + request_param.flags = UCP_AM_SEND_FLAG_REPLY; + + const int64_t size = buffer->size(); + if (size == 0) { + // UCX appears to crash on zero-byte payloads + ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(1, write_memory_pool_)); + } + + std::unique_ptr pending_send(new PendingContigSend()); + RETURN_NOT_OK(pending_send->header.Set(frame_type, counter_++, size)); + pending_send->ipc_message = std::move(buffer); + pending_send->driver = this; + pending_send->completed = Future<>::Make(); + pending_send->memh_p = nullptr; + + request_param.user_data = pending_send.release(); + { + auto* pending_send = reinterpret_cast(request_param.user_data); + + void* request = ucp_am_send_nbx( + endpoint_, kUcpAmHandlerId, pending_send->header.data(), + pending_send->header.size(), + reinterpret_cast(pending_send->ipc_message->mutable_data()), + static_cast(pending_send->ipc_message->size()), &request_param); + if (!request) { + // Request completed immediately + delete pending_send; + return Status::OK(); + } else if (UCS_PTR_IS_ERR(request)) { + delete pending_send; + return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request)); + } + return pending_send->completed; + } + } + + Future<> SendFlightPayload(const FlightPayload& payload) { + static const int64_t kMaxBatchSize = std::numeric_limits::max(); + RETURN_NOT_OK(CheckClosed()); + + if (payload.ipc_message.body_length > kMaxBatchSize) { + return Status::Invalid("Cannot send record batches exceeding 2GiB yet"); + } + + { + ARROW_ASSIGN_OR_RAISE(auto frame, + PayloadHeaderFrame::Make(payload, write_memory_pool_)); + RETURN_NOT_OK(SendFrame(FrameType::kPayloadHeader, frame.data(), frame.size())); + } + + if (!ipc::Message::HasBody(payload.ipc_message.type)) { + return Status::OK(); + } + + // While IOV (scatter-gather) might seem like it avoids a memcpy, + // profiling shows that at least for the TCP/SHM/RDMA transports, + // UCX just does a memcpy internally. Furthermore, on the receiver + // side, a sender-side IOV send prevents optimizations based on + // mapped buffers (UCX will memcpy to the destination buffer + // regardless of whether it's mapped or not). + + // If all buffers are on the CPU, concatenate them ourselves and + // do a regular send to avoid this. Else, use IOV and let UCX + // figure out what to do. + + // Weirdness: UCX prefers TCP over shared memory for CONTIG? We + // can avoid this by setting UCX_RNDV_THRESH=inf, this will make + // UCX prefer shared memory again. However, we still want to avoid + // the CONTIG path when shared memory is available, because the + // total amount of time spent in memcpy is greater than using IOV + // and letting UCX handle it. + + // Consider: if we can figure out how to make IOV always as fast + // as CONTIG, we can just send the metadata fields as part of the + // IOV payload and avoid having to send two distinct messages. + + bool all_cpu = true; + int32_t total_buffers = 0; + for (const auto& buffer : payload.ipc_message.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + all_cpu = all_cpu && buffer->is_cpu(); + total_buffers++; + + // Arrow IPC requires that we align buffers to 8 byte boundary + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) total_buffers++; + } + + ucp_request_param_t request_param; + std::memset(&request_param, 0, sizeof(request_param)); + request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA; + request_param.cb.send = AmSendCallback; + request_param.flags = UCP_AM_SEND_FLAG_REPLY; + + std::unique_ptr pending_send; + void* send_data = nullptr; + size_t send_size = 0; + + if (!all_cpu) { + request_param.op_attr_mask = + request_param.op_attr_mask | UCP_OP_ATTR_FIELD_MEMORY_TYPE; + // XXX: UCX doesn't appear to autodetect this correctly if we + // use UNKNOWN + request_param.memory_type = UCS_MEMORY_TYPE_CUDA; + } + + if (kEnableContigSend && all_cpu) { + // CONTIG - concatenate buffers into one before sending + + // TODO(lidavidm): this needs to be pipelined since it can be expensive. + // Preliminary profiling shows ~5% overhead just from mapping the buffer + // alone (on Infiniband; it seems to be trivial for shared memory) + request_param.datatype = ucp_dt_make_contig(1); + pending_send = arrow::internal::make_unique(); + auto* pending_contig = reinterpret_cast(pending_send.get()); + + ARROW_ASSIGN_OR_RAISE( + pending_contig->ipc_message, + AllocateBuffer(payload.ipc_message.body_length, write_memory_pool_)); + TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message, + &pending_contig->memh_p); + + uint8_t* ipc_message = pending_contig->ipc_message->mutable_data(); + for (const auto& buffer : payload.ipc_message.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + + std::memcpy(ipc_message, buffer->data(), buffer->size()); + ipc_message += buffer->size(); + + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) { + std::memset(ipc_message, 0, remainder); + ipc_message += remainder; + } + } + + send_data = reinterpret_cast(pending_contig->ipc_message->mutable_data()); + send_size = static_cast(pending_contig->ipc_message->size()); + } else { + // IOV - let UCX use scatter-gather path + request_param.datatype = UCP_DATATYPE_IOV; + pending_send = arrow::internal::make_unique(); + auto* pending_iov = reinterpret_cast(pending_send.get()); + + pending_iov->payload = payload; + pending_iov->iovs.resize(total_buffers); + ucp_dt_iov_t* iov = pending_iov->iovs.data(); +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + // XXX: this seems to have no benefits in tests so far + pending_iov->memh_ps.resize(total_buffers); + ucp_mem_h* memh_p = pending_iov->memh_ps.data(); +#endif + for (const auto& buffer : payload.ipc_message.body_buffers) { + if (!buffer || buffer->size() == 0) continue; + + iov->buffer = const_cast(reinterpret_cast(buffer->address())); + iov->length = buffer->size(); + ++iov; + +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + TryMapBuffer(worker_->context().get(), *buffer, memh_p); + memh_p++; +#endif + + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) { + iov->buffer = + const_cast(reinterpret_cast(padding_bytes_.data())); + iov->length = remainder; + ++iov; + } + } + + send_data = pending_iov->iovs.data(); + send_size = pending_iov->iovs.size(); + } + + RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++, + payload.ipc_message.body_length)); + pending_send->driver = this; + pending_send->completed = Future<>::Make(); + + request_param.user_data = pending_send.release(); + { + auto* pending_send = reinterpret_cast(request_param.user_data); + + void* request = ucp_am_send_nbx( + endpoint_, kUcpAmHandlerId, pending_send->header.data(), + pending_send->header.size(), send_data, send_size, &request_param); + if (!request) { + // Request completed immediately + delete pending_send; + return Status::OK(); + } else if (UCS_PTR_IS_ERR(request)) { + delete pending_send; + return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request)); + } + return pending_send->completed; + } + } + + Status Close() { + if (!endpoint_) return Status::OK(); + + for (auto& item : frames_) { + item.second.MarkFinished(Status::Cancelled("UcpCallDriver is being closed")); + } + frames_.clear(); + + void* request = ucp_ep_close_nb(endpoint_, UCP_EP_CLOSE_MODE_FLUSH); + ucs_status_t status = UCS_OK; + std::string origin = "ucp_ep_close_nb"; + if (UCS_PTR_IS_ERR(request)) { + status = UCS_PTR_STATUS(request); + } else if (UCS_PTR_IS_PTR(request)) { + origin = "ucp_request_check_status"; + while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) { + MakeProgress(); + } + ucp_request_free(request); + } else { + DCHECK(!request); + } + + endpoint_ = nullptr; + if (status != UCS_OK && status != UCS_ERR_ENDPOINT_TIMEOUT && + status != UCS_ERR_NOT_CONNECTED) { + // Ignore timeout, not connected + return FromUcsStatus(origin, status); + } + return Status::OK(); + } + + void MakeProgress() { ucp_worker_progress(worker_->get()); } + + void Push(std::shared_ptr frame) { + std::unique_lock guard(frame_mutex_); + if (ARROW_PREDICT_FALSE(!status_.ok())) return; + auto pair = frames_.insert({frame->counter, frame}); + if (!pair.second) { + pair.first->second.MarkFinished(std::move(frame)); + frames_.erase(pair.first); + } + } + + void Push(Status status) { + std::unique_lock guard(frame_mutex_); + status_ = std::move(status); + for (auto& item : frames_) { + item.second.MarkFinished(status_); + } + frames_.clear(); + } + + ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data, + const size_t data_length, + const ucp_am_recv_param_t* param) { + auto maybe_status = + RecvActiveMessageImpl(header, header_length, data, data_length, param); + if (!maybe_status.ok()) { + Push(maybe_status.status()); + return UCS_OK; + } + return maybe_status.MoveValueUnsafe(); + } + + const std::shared_ptr& memory_manager() const { return memory_manager_; } + void set_memory_manager(std::shared_ptr memory_manager) { + if (memory_manager) { + memory_manager_ = std::move(memory_manager); + } else { + memory_manager_ = CPUDevice::Instance()->default_memory_manager(); + } + } + void set_read_memory_pool(MemoryPool* pool) { + read_memory_pool_ = pool ? pool : default_memory_pool(); + } + void set_write_memory_pool(MemoryPool* pool) { + write_memory_pool_ = pool ? pool : default_memory_pool(); + } + + private: + class PendingAmSend { + public: + virtual ~PendingAmSend() = default; + UcpCallDriver::Impl* driver; + Future<> completed; + FrameHeader header; + }; + + class PendingContigSend : public PendingAmSend { + public: + std::unique_ptr ipc_message; + ucp_mem_h memh_p; + + virtual ~PendingContigSend() { + TryUnmapBuffer(driver->worker_->context().get(), memh_p); + } + }; + + class PendingIovSend : public PendingAmSend { + public: + FlightPayload payload; + std::vector iovs; +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + std::vector memh_ps; + + virtual ~PendingIovSend() { + for (ucp_mem_h memh_p : memh_ps) { + TryUnmapBuffer(driver->worker_->context().get(), memh_p); + } + } +#endif + }; + + struct PendingAmRecv { + UcpCallDriver::Impl* driver; + std::shared_ptr frame; + ucp_mem_h memh_p; + + PendingAmRecv(UcpCallDriver::Impl* driver_, std::shared_ptr frame_) + : driver(driver_), frame(std::move(frame_)) {} + + ~PendingAmRecv() { TryUnmapBuffer(driver->worker_->context().get(), memh_p); } + }; + + static void AmSendCallback(void* request, ucs_status_t status, void* user_data) { + auto* pending_send = reinterpret_cast(user_data); + if (status == UCS_OK) { + pending_send->completed.MarkFinished(); + } else { + pending_send->completed.MarkFinished(FromUcsStatus("ucp_am_send_nbx", status)); + } + // TODO(lidavidm): delete should occur on a background thread if there's mapped + // buffers, since unmapping can be nontrivial and we don't want to block + // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.) + delete pending_send; + ucp_request_free(request); + } + + static void AmRecvCallback(void* request, ucs_status_t status, size_t length, + void* user_data) { + auto* pending_recv = reinterpret_cast(user_data); + ucp_request_free(request); + if (status != UCS_OK) { + pending_recv->driver->Push( + FromUcsStatus("ucp_am_recv_data_nbx (callback)", status)); + } else { + pending_recv->driver->Push(std::move(pending_recv->frame)); + } + delete pending_recv; + } + + arrow::Result RecvActiveMessageImpl(const void* header, + size_t header_length, void* data, + const size_t data_length, + const ucp_am_recv_param_t* param) { + DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP); + + if (data_length > static_cast(std::numeric_limits::max())) { + return Status::Invalid( + "Cannot allocate buffer greater than int64_t max, requested: ", data_length); + } + + ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length)); + if (data_length < frame->size) { + return Status::IOError("Expected frame of ", frame->size, " bytes, but got only ", + data_length); + } + + if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) && + (frame->type != FrameType::kPayloadBody || memory_manager_->is_cpu())) { + // Zero-copy path. UCX-allocated buffer must be freed later. + + // XXX: this buffer can NOT be freed until AFTER we return from + // this handler. Otherwise, UCX won't have fully set up its + // internal data structures (allocated just before the buffer) + // and we'll crash when we free the buffer. Effectively: we can + // never use Then/AddCallback on a Future<> from ReadFrameAsync, + // because we might run the callback synchronously (which might + // free the buffer) when we call Push here. + frame->buffer = + arrow::internal::make_unique(worker_, data, data_length); + Push(std::move(frame)); + return UCS_INPROGRESS; + } + + if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) || + (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV)) { + // Rendezvous protocol (RNDV), or unpack to destination (DATA). + + // We want to map/pin/register the buffer for faster transfer + // where possible. (It gets unmapped in ~PendingAmRecv.) + // TODO(lidavidm): This takes non-trivial time, so return + // UCS_INPROGRESS, kick off the allocation in the background, + // and recv the data later (is it allowed to call + // ucp_am_recv_data_nbx asynchronously?). + if (frame->type == FrameType::kPayloadBody) { + ARROW_ASSIGN_OR_RAISE(frame->buffer, + memory_manager_->AllocateBuffer(data_length)); + } else { + ARROW_ASSIGN_OR_RAISE(frame->buffer, + AllocateBuffer(data_length, read_memory_pool_)); + } + + PendingAmRecv* pending_recv = new PendingAmRecv(this, std::move(frame)); + TryMapBuffer(worker_->context().get(), *pending_recv->frame->buffer, + &pending_recv->memh_p); + + ucp_request_param_t recv_param; + recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_MEMORY_TYPE | + UCP_OP_ATTR_FIELD_USER_DATA; + recv_param.cb.recv_am = AmRecvCallback; + recv_param.user_data = pending_recv; + recv_param.memory_type = InferMemoryType(*pending_recv->frame->buffer); + + void* dest = + reinterpret_cast(pending_recv->frame->buffer->mutable_address()); + void* request = + ucp_am_recv_data_nbx(worker_->get(), data, dest, data_length, &recv_param); + if (UCS_PTR_IS_ERR(request)) { + delete pending_recv; + return FromUcsStatus("ucp_am_recv_data_nbx", UCS_PTR_STATUS(request)); + } else if (!request) { + // Request completed instantly + Push(std::move(pending_recv->frame)); + delete pending_recv; + } + return UCS_OK; + } else { + // Data will be freed after callback returns - copy to buffer + if (frame->type != FrameType::kPayloadBody || memory_manager_->is_cpu()) { + ARROW_ASSIGN_OR_RAISE(frame->buffer, + AllocateBuffer(data_length, read_memory_pool_)); + std::memcpy(frame->buffer->mutable_data(), data, data_length); + } else { + ARROW_ASSIGN_OR_RAISE( + frame->buffer, + MemoryManager::CopyNonOwned(Buffer(reinterpret_cast(data), + static_cast(data_length)), + memory_manager_)); + } + Push(std::move(frame)); + return UCS_OK; + } + } + + Status CompleteRequestBlocking(const std::string& context, void* request) { + if (UCS_PTR_IS_ERR(request)) { + return FromUcsStatus(context, UCS_PTR_STATUS(request)); + } else if (UCS_PTR_IS_PTR(request)) { + while (true) { + auto status = ucp_request_check_status(request); + if (status == UCS_OK) { + break; + } else if (status != UCS_INPROGRESS) { + ucp_request_release(request); + return FromUcsStatus("ucp_request_check_status", status); + } + MakeProgress(); + } + ucp_request_free(request); + } else { + // Send was completed instantly + DCHECK(!request); + } + return Status::OK(); + } + + Status CheckClosed() { + if (!endpoint_) { + return Status::Invalid("UcpCallDriver is closed"); + } + return Status::OK(); + } + + const std::array padding_bytes_; +#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) + ucp_mem_h padding_memh_p_; +#endif + + std::shared_ptr worker_; + ucp_ep_h endpoint_; + MemoryPool* read_memory_pool_; + MemoryPool* write_memory_pool_; + std::shared_ptr memory_manager_; + + // Internal name for logging/tracing + std::string name_; + // Counter used to reorder messages + uint32_t counter_ = 0; + + std::mutex frame_mutex_; + Status status_; + std::unordered_map>> frames_; + uint32_t next_counter_ = 0; +}; + +UcpCallDriver::UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint) + : impl_(new Impl(std::move(worker), endpoint)) {} +UcpCallDriver::UcpCallDriver(UcpCallDriver&&) = default; +UcpCallDriver& UcpCallDriver::operator=(UcpCallDriver&&) = default; +UcpCallDriver::~UcpCallDriver() = default; + +arrow::Result> UcpCallDriver::ReadNextFrame() { + return impl_->ReadNextFrame(); +} + +Future> UcpCallDriver::ReadFrameAsync() { + return impl_->ReadFrameAsync(); +} + +Status UcpCallDriver::ExpectFrameType(const Frame& frame, FrameType type) { + if (frame.type != type) { + return Status::IOError("Expected frame type ", static_cast(type), + ", but got frame type ", static_cast(frame.type)); + } + return Status::OK(); +} + +Status UcpCallDriver::StartCall(const std::string& method) { + std::vector> headers; + headers.emplace_back(kHeaderMethod, method); + ARROW_ASSIGN_OR_RAISE(auto frame, HeadersFrame::Make(headers)); + auto buffer = std::move(frame).GetBuffer(); + RETURN_NOT_OK(impl_->SendFrame(FrameType::kHeaders, buffer->data(), buffer->size())); + return Status::OK(); +} + +Future<> UcpCallDriver::SendFlightPayload(const FlightPayload& payload) { + return impl_->SendFlightPayload(payload); +} + +Status UcpCallDriver::SendFrame(FrameType frame_type, const uint8_t* data, + const int64_t size) { + return impl_->SendFrame(frame_type, data, size); +} + +Future<> UcpCallDriver::SendFrameAsync(FrameType frame_type, + std::unique_ptr buffer) { + return impl_->SendFrameAsync(frame_type, std::move(buffer)); +} + +Status UcpCallDriver::Close() { return impl_->Close(); } + +void UcpCallDriver::MakeProgress() { impl_->MakeProgress(); } + +ucs_status_t UcpCallDriver::RecvActiveMessage(const void* header, size_t header_length, + void* data, const size_t data_length, + const ucp_am_recv_param_t* param) { + return impl_->RecvActiveMessage(header, header_length, data, data_length, param); +} + +const std::shared_ptr& UcpCallDriver::memory_manager() const { + return impl_->memory_manager(); +} + +void UcpCallDriver::set_memory_manager(std::shared_ptr memory_manager) { + impl_->set_memory_manager(std::move(memory_manager)); +} +void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) { + impl_->set_read_memory_pool(pool); +} +void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) { + impl_->set_write_memory_pool(pool); +} + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h new file mode 100644 index 00000000000..389ab2ea2a8 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h @@ -0,0 +1,352 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Common implementation of UCX communication primitives. + +#pragma once + +#include +#include +#include +#include + +#include + +#include "arrow/buffer.h" +#include "arrow/flight/server.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/flight/visibility.h" +#include "arrow/type_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +//------------------------------------------------------------ +// Protocol Constants + +static constexpr char kMethodDoExchange[] = "DoExchange"; +static constexpr char kMethodDoGet[] = "DoGet"; +static constexpr char kMethodDoPut[] = "DoPut"; +static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo"; + +static constexpr char kHeaderStatusCode[] = "flight-status-code"; +static constexpr char kHeaderStatusMessage[] = "flight-status-message"; +static constexpr char kHeaderStatusDetail[] = "flight-status-detail"; +static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code"; + +//------------------------------------------------------------ +// UCX Helpers + +/// \brief A wrapper around a ucp_context_h. +/// +/// Used so that multiple resources can share ownership of the +/// context. UCX has zero-copy optimizations where an application can +/// directly use a UCX buffer, but the lifetime of such buffers is +/// tied to the UCX context and worker, so ownership needs to be +/// preserved. +class UcpContext final { + public: + UcpContext() : ucp_context_(nullptr) {} + explicit UcpContext(ucp_context_h context) : ucp_context_(context) {} + ~UcpContext() { + if (ucp_context_) ucp_cleanup(ucp_context_); + ucp_context_ = nullptr; + } + ucp_context_h get() const { + DCHECK(ucp_context_); + return ucp_context_; + } + + private: + ucp_context_h ucp_context_; +}; + +/// \brief A wrapper around a ucp_worker_h. +class UcpWorker final { + public: + UcpWorker() : ucp_worker_(nullptr) {} + UcpWorker(std::shared_ptr context, ucp_worker_h worker) + : ucp_context_(std::move(context)), ucp_worker_(worker) {} + ~UcpWorker() { + if (ucp_worker_) ucp_worker_destroy(ucp_worker_); + ucp_worker_ = nullptr; + } + ucp_worker_h get() const { + DCHECK(ucp_worker_); + return ucp_worker_; + } + const UcpContext& context() const { return *ucp_context_; } + + private: + std::shared_ptr ucp_context_; + ucp_worker_h ucp_worker_; +}; + +//------------------------------------------------------------ +// Message Framing + +/// \brief The message type. +enum class FrameType : uint8_t { + /// Key-value headers. Sent at the beginning (client->server) and + /// end (server->client) of a call. Also, for client-streaming calls + /// (e.g. DoPut), the client should send a headers frame to signal + /// end-of-stream. + kHeaders = 0, + /// Binary blob, does not contain Arrow data. + kBuffer, + /// Binary blob. Contains IPC metadata, app metadata. + kPayloadHeader, + /// Binary blob. Contains IPC body. Body is sent separately since it + /// may use a different memory type. + kPayloadBody, + /// Ask server to disconnect (to avoid client/server waiting on each + /// other and getting stuck). + kDisconnect, + /// Keep at end. + kMaxFrameType = kDisconnect, +}; + +/// \brief The header of a message frame. Used when sending only. +/// +/// A frame is expected to be sent over UCP Active Messages and +/// consists of a header (of kFrameHeaderBytes bytes) and a body. +/// +/// The header is as follows: +/// +-------+---------------------------------+ +/// | Bytes | Function | +/// +=======+=================================+ +/// | 0 | Version tag (see kFrameVersion) | +/// | 1 | Frame type (see FrameType) | +/// | 2-3 | Unused, reserved | +/// | 4-7 | Frame counter (big-endian) | +/// | 8-11 | Body size (big-endian) | +/// +-------+---------------------------------+ +/// +/// The frame counter lets the receiver ensure messages are processed +/// in-order. (The message receive callback may use +/// ucp_am_recv_data_nbx which is asynchronous.) +/// +/// The body size reports the expected message size (UCX chokes on +/// zero-size payloads which we occasionally want to send, so the size +/// field in the header lets us know when a payload was meant to be +/// empty). +struct FrameHeader { + /// \brief The size of a frame header. + static constexpr size_t kFrameHeaderBytes = 12; + /// \brief The expected version tag in the header. + static constexpr uint8_t kFrameVersion = 0x01; + + FrameHeader() = default; + /// \brief Initialize the frame header. + Status Set(FrameType frame_type, uint32_t counter, int64_t body_size); + void* data() const { return header.data(); } + size_t size() const { return kFrameHeaderBytes; } + + // mutable since UCX expects void* not const void* + mutable std::array header = {0}; +}; + +/// \brief A single message received via UCX. Used when receiving only. +struct Frame { + /// \brief The message type. + FrameType type; + /// \brief The message length. + uint32_t size; + /// \brief An incrementing message counter (may wrap over). + uint32_t counter; + /// \brief The message contents. + std::unique_ptr buffer; + + Frame() = default; + Frame(FrameType type_, uint32_t size_, uint32_t counter_, + std::unique_ptr buffer_) + : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {} + + util::string_view view() const { + return util::string_view(reinterpret_cast(buffer->data()), size); + } + + /// \brief Parse a UCX active message header. This will not + /// initialize the buffer field. + static arrow::Result> ParseHeader(const void* header, + size_t header_length); +}; + +/// \brief The active message handler callback ID. +static constexpr uint32_t kUcpAmHandlerId = 0x1024; + +/// \brief A collection of key-value headers. +/// +/// This should be stored in a frame of type kHeaders. +/// +/// Format: +/// +-------+----------------------------------+ +/// | Bytes | Contents | +/// +=======+==================================+ +/// | 0-4 | # of headers (big-endian) | +/// | 4-8 | Header key length (big-endian) | +/// | 2-3 | Header value length (big-endian) | +/// | (...) | Header key | +/// | (...) | Header value | +/// | (...) | (repeat from row 2) | +/// +-------+----------------------------------+ +class HeadersFrame { + public: + /// \brief Get a header value (or an error if it was not found) + arrow::Result Get(const std::string& key); + /// \brief Extract the server-sent status. + Status GetStatus(Status* out); + /// \brief Parse the headers from the buffer. + static arrow::Result Parse(std::unique_ptr buffer); + /// \brief Create a new frame with the given headers. + static arrow::Result Make( + const std::vector>& headers); + /// \brief Create a new frame with the given headers and the given status. + static arrow::Result Make( + const Status& status, + const std::vector>& headers); + + /// \brief Take ownership of the underlying buffer. + std::unique_ptr GetBuffer() && { return std::move(buffer_); } + + private: + std::unique_ptr buffer_; + std::vector> headers_; +}; + +/// \brief A representation of a kPayloadHeader frame (i.e. all of the +/// metadata in a FlightPayload/FlightData). +/// +/// Data messages are sent in two parts: one containing all metadata +/// (the Flatbuffers header, FlightDescriptor, and app_metadata +/// fields) and one containing the actual data. This was done to avoid +/// having to concatenate these fields with the data itself (in the +/// cases where we are not using IOV). +/// +/// Format: +/// +--------+----------------------------------+ +/// | Bytes | Contents | +/// +========+==================================+ +/// | 0-4 | Descriptor length (big-endian) | +/// | 4..a | Descriptor bytes | +/// | a-a+4 | app_metadata length (big-endian) | +/// | a+4..b | app_metadata bytes | +/// | b-b+4 | ipc_metadata length (big-endian) | +/// | b+4..c | ipc_metadata bytes | +/// +--------+----------------------------------+ +/// +/// If a field is not present, its length is still there, but is set +/// to UINT32_MAX. +class PayloadHeaderFrame { + public: + explicit PayloadHeaderFrame(std::unique_ptr buffer) + : buffer_(std::move(buffer)) {} + /// \brief Unpack the internal buffer into a FlightData. + Status ToFlightData(internal::FlightData* data); + /// \brief Pack a payload into the internal buffer. + static arrow::Result Make(const FlightPayload& payload, + MemoryPool* memory_pool); + const uint8_t* data() const { return buffer_->data(); } + int64_t size() const { return buffer_->size(); } + + private: + std::unique_ptr buffer_; +}; + +/// \brief Manage the state of a UCX connection. +class UcpCallDriver { + public: + UcpCallDriver(std::shared_ptr worker, ucp_ep_h endpoint); + + UcpCallDriver(const UcpCallDriver&) = delete; + UcpCallDriver(UcpCallDriver&&); + void operator=(const UcpCallDriver&) = delete; + UcpCallDriver& operator=(UcpCallDriver&&); + + ~UcpCallDriver(); + + /// \brief Start a call by sending a headers frame. Client side only. + /// + /// \param[in] method The RPC method. + Status StartCall(const std::string& method); + + /// \brief Synchronously send a generic message with binary payload. + Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size); + /// \brief Asynchronously send a generic message with binary payload. + /// + /// The UCP driver must be manually polled (call MakeProgress()). + Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr buffer); + /// \brief Asynchronously send a data message. + /// + /// The UCP driver must be manually polled (call MakeProgress()). + Future<> SendFlightPayload(const FlightPayload& payload); + + /// \brief Synchronously read the next frame. + arrow::Result> ReadNextFrame(); + /// \brief Asynchronously read the next frame. + /// + /// The UCP driver must be manually polled (call MakeProgress()). + Future> ReadFrameAsync(); + + /// \brief Validate that the frame is of the given type. + Status ExpectFrameType(const Frame& frame, FrameType type); + + /// \brief Disconnect the other side of the connection. Note, this + /// can cause deadlock. + Status Close(); + + /// \brief Synchronously make progress (to adapt async to sync APIs) + void MakeProgress(); + + /// \brief Get the associated memory manager. + const std::shared_ptr& memory_manager() const; + /// \brief Set the associated memory manager. + void set_memory_manager(std::shared_ptr memory_manager); + /// \brief Set memory pool for scratch space used during reading. + void set_read_memory_pool(MemoryPool* memory_pool); + /// \brief Set memory pool for scratch space used during writing. + void set_write_memory_pool(MemoryPool* memory_pool); + + /// \brief Process an incoming active message. This will unblock the + /// corresponding call to ReadFrameAsync/ReadNextFrame. + ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data, + const size_t data_length, + const ucp_am_recv_param_t* param); + + private: + class Impl; + std::unique_ptr impl_; +}; + +ARROW_FLIGHT_EXPORT +std::unique_ptr MakeUcxClientImpl(); + +ARROW_FLIGHT_EXPORT +std::unique_ptr MakeUcxServerImpl( + FlightServerBase* base, std::shared_ptr memory_manager); + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc new file mode 100644 index 00000000000..b9612a19d0a --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -0,0 +1,660 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/ucx_internal.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/server.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/ucx/util_internal.h" +#include "arrow/flight/transport_server.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +// Send an error to the client and return OK. +// Statuses returned up to the main server loop trigger a kReset instead. +#define SERVER_RETURN_NOT_OK(driver, status) \ + do { \ + ::arrow::Status s = (status); \ + if (!s.ok()) { \ + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(s, {})); \ + auto payload = std::move(headers).GetBuffer(); \ + RETURN_NOT_OK( \ + driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); \ + return ::arrow::Status::OK(); \ + } \ + } while (false) + +#define FLIGHT_LOG(LEVEL) (ARROW_LOG(LEVEL) << "[server] ") +#define FLIGHT_LOG_PEER(LEVEL, PEER) \ + (ARROW_LOG(LEVEL) << "[server]" \ + << "[peer=" << (PEER) << "] ") + +namespace { +class UcxServerCallContext : public flight::ServerCallContext { + public: + const std::string& peer_identity() const override { return peer_; } + const std::string& peer() const override { return peer_; } + ServerMiddleware* GetMiddleware(const std::string& key) const override { + return nullptr; + } + bool is_cancelled() const override { return false; } + + private: + std::string peer_; +}; + +class UcxServerStream : public internal::ServerDataStream { + public: + // TODO(lidavidm): backpressure threshold should be dynamic (ideally + // auto-adjusted, or at least configurable) + constexpr static size_t kBackpressureThreshold = 8; + + UcxServerStream(std::string peer, UcpCallDriver* driver) + : peer_(std::move(peer)), driver_(driver), writes_done_(false) {} + + Status WritesDone() override { + RETURN_NOT_OK(CheckBackpressure(0)); + writes_done_ = true; + return Status::OK(); + } + + protected: + Status CheckBackpressure(size_t limit = kBackpressureThreshold - 1) { + while (requests_.size() > limit) { + auto& next = requests_.front(); + while (!next.is_finished()) { + driver_->MakeProgress(); + } + RETURN_NOT_OK(next.status()); + requests_.pop(); + } + return Status::OK(); + } + + std::string peer_; + UcpCallDriver* driver_; + bool writes_done_; + std::queue> requests_; +}; + +class GetServerStream : public UcxServerStream { + public: + using UcxServerStream::UcxServerStream; + + arrow::Result WriteData(const FlightPayload& payload) override { + if (writes_done_) return false; + RETURN_NOT_OK(CheckBackpressure()); + Future<> pending_send = driver_->SendFlightPayload(payload); + if (!pending_send.is_finished()) { + requests_.push(std::move(pending_send)); + } else { + // Request completed instantly + RETURN_NOT_OK(pending_send.status()); + } + return true; + } +}; + +class PutServerStream : public UcxServerStream { + public: + PutServerStream(std::string peer, UcpCallDriver* driver) + : UcxServerStream(std::move(peer), driver), finished_(false) {} + + bool ReadData(internal::FlightData* data) override { + if (finished_) return false; + + bool success = true; + auto status = ReadImpl(data).Value(&success); + + if (!status.ok() || !success) { + finished_ = true; + if (!status.ok()) { + FLIGHT_LOG_PEER(WARNING, peer_) << "I/O error in DoPut: " << status.ToString(); + return false; + } + } + return success; + } + + Status WritePutMetadata(const Buffer& payload) override { + if (finished_) return Status::OK(); + // Send synchronously (we don't control payload lifetime) + RETURN_NOT_OK(driver_->SendFrame(FrameType::kBuffer, payload.data(), payload.size())); + return Status::OK(); + } + + private: + ::arrow::Result ReadImpl(internal::FlightData* data) { + ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame()); + if (frame->type == FrameType::kHeaders) { + // Trailers, client is done writing + return false; + } + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader)); + PayloadHeaderFrame payload_header(std::move(frame->buffer)); + RETURN_NOT_OK(payload_header.ToFlightData(data)); + + if (data->metadata) { + ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr)); + + if (ipc::Message::HasBody(message->type())) { + ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame()); + RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody)); + data->body = std::move(frame->buffer); + } + } + return true; + } + + bool finished_; +}; + +class ExchangeServerStream : public PutServerStream { + public: + using PutServerStream::PutServerStream; + + arrow::Result WriteData(const FlightPayload& payload) override { + if (writes_done_) return false; + // Don't use backpressure - the application may expect synchronous + // behavior (write a message, read the client response) + Future<> pending_send = driver_->SendFlightPayload(payload); + while (!pending_send.is_finished()) { + driver_->MakeProgress(); + } + RETURN_NOT_OK(pending_send.status()); + return true; + } + Status WritePutMetadata(const Buffer& payload) override { + return Status::NotImplemented("Not supported on this stream"); + } +}; + +arrow::Result SockaddrToString(const struct sockaddr_storage& address) { + std::string result; + if (address.ss_family != AF_INET && address.ss_family != AF_INET6) { + return Status::NotImplemented("Unknown address family"); + } + + uint16_t port = 0; + if (address.ss_family == AF_INET) { + result.resize(INET_ADDRSTRLEN + 1); + port = ntohs(reinterpret_cast(&address)->sin_port); + result[INET_ADDRSTRLEN] = ':'; + result += std::to_string(port); + } else { + result.resize(INET6_ADDRSTRLEN + 1); + port = ntohs(reinterpret_cast(&address)->sin6_port); + result[INET_ADDRSTRLEN] = ':'; + result += std::to_string(port); + } + if (!inet_ntop(address.ss_family, &address, &result[0], result.size())) { + return arrow::internal::IOErrorFromErrno(errno, + "Could not convert address to string"); + } + + return result; +} +} // namespace + +class ARROW_FLIGHT_EXPORT UcxServerImpl + : public arrow::flight::internal::ServerTransport { + public: + using arrow::flight::internal::ServerTransport::ServerTransport; + + virtual ~UcxServerImpl() { + if (listening_.load()) { + auto st = Shutdown(); + if (!st.ok()) { + ARROW_LOG(WARNING) << "Server did not shut down properly: " << st.ToString(); + } + } + } + + Status Init(const FlightServerOptions& options, const arrow::internal::Uri& uri) { + ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(8)); + + // Init UCX + { + ucp_config_t* ucp_config; + ucp_params_t ucp_params; + ucs_status_t status; + + status = ucp_config_read(nullptr, nullptr, &ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); + + // Allow application to override UCP config + if (options.builder_hook) options.builder_hook(ucp_config); + + std::memset(&ucp_params, 0, sizeof(ucp_params)); + ucp_params.field_mask = + UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED; + ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; + ucp_params.mt_workers_shared = UCS_THREAD_MODE_MULTI; + + ucp_context_h ucp_context; + status = ucp_init(&ucp_params, ucp_config, &ucp_context); + ucp_config_release(ucp_config); + RETURN_NOT_OK(FromUcsStatus("ucp_init", status)); + ucp_context_.reset(new UcpContext(ucp_context)); + } + + { + // Create one worker to listen for incoming connections. + ucp_worker_params_t worker_params; + ucs_status_t status; + + std::memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_MULTI; + ucp_worker_h worker; + status = ucp_worker_create(ucp_context_->get(), &worker_params, &worker); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); + worker_conn_.reset(new UcpWorker(ucp_context_, worker)); + } + + // Start listening for connections. + { + ucp_listener_params_t params; + ucs_status_t status; + + struct sockaddr_storage listen_addr; + ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); + + params.field_mask = + UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; + params.sockaddr.addr = reinterpret_cast(&listen_addr); + params.sockaddr.addrlen = addrlen; + params.conn_handler.cb = HandleIncomingConnection; + params.conn_handler.arg = this; + + status = ucp_listener_create(worker_conn_->get(), ¶ms, &listener_); + RETURN_NOT_OK(FromUcsStatus("ucp_listener_create", status)); + + // Get the real address/port + ucp_listener_attr_t attr; + attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR; + status = ucp_listener_query(listener_, &attr); + RETURN_NOT_OK(FromUcsStatus("ucp_listener_query", status)); + + std::string raw_uri = "ucx://"; + if (uri.host().find(':') != std::string::npos) { + // IPv6 host + raw_uri += '['; + raw_uri += uri.host(); + raw_uri += ']'; + } else { + raw_uri += uri.host(); + } + raw_uri += ":"; + raw_uri += std::to_string( + ntohs(reinterpret_cast(&attr.sockaddr)->sin_port)); + RETURN_NOT_OK(Location::Parse(raw_uri, &location_)); + } + + { + listening_.store(true); + std::thread listener_thread(&UcxServerImpl::DriveConnections, this); + listener_thread_.swap(listener_thread); + } + + return Status::OK(); + } + + Status Shutdown() override { + if (!listening_.load()) return Status::OK(); + Status status; + + // Wait for current RPCs to finish + listening_.store(false); + RETURN_NOT_OK( + FromUcsStatus("ucp_worker_signal", ucp_worker_signal(worker_conn_->get()))); + status &= Wait(); + + { + // Reject all pending connections + std::unique_lock guard(pending_connections_mutex_); + while (!pending_connections_.empty()) { + status &= + FromUcsStatus("ucp_listener_reject", + ucp_listener_reject(listener_, pending_connections_.front())); + pending_connections_.pop(); + } + ucp_listener_destroy(listener_); + worker_conn_.reset(); + } + + status &= rpc_pool_->Shutdown(); + rpc_pool_.reset(); + + ucp_context_.reset(); + return status; + } + + Status Shutdown(const std::chrono::system_clock::time_point& deadline) override { + // TODO(lidavidm): implement shutdown with deadline + return Shutdown(); + } + + Status Wait() override { + std::unique_lock guard(join_mutex_); + try { + listener_thread_.join(); + } catch (const std::system_error& e) { + if (e.code() != std::errc::invalid_argument) { + return Status::UnknownError("Could not Wait(): ", e.what()); + } + // Else, server wasn't running anyways + } + return Status::OK(); + } + + Location location() const override { return location_; } + + private: + struct ClientWorker { + std::shared_ptr worker; + std::unique_ptr driver; + }; + + Status SendStatus(UcpCallDriver* driver, const Status& status) { + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(status, {})); + auto payload = std::move(headers).GetBuffer(); + RETURN_NOT_OK( + driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); + return Status::OK(); + } + + Status HandleGetFlightInfo(const std::string& peer, UcpCallDriver* driver) { + UcxServerCallContext context; + + ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); + SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); + FlightDescriptor descriptor; + SERVER_RETURN_NOT_OK(driver, + FlightDescriptor::Deserialize(util::string_view(*frame->buffer)) + .Value(&descriptor)); + + std::unique_ptr info; + std::string response; + SERVER_RETURN_NOT_OK(driver, base_->GetFlightInfo(context, descriptor, &info)); + SERVER_RETURN_NOT_OK(driver, info->SerializeToString().Value(&response)); + RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, + reinterpret_cast(response.data()), + static_cast(response.size()))); + RETURN_NOT_OK(SendStatus(driver, Status::OK())); + return Status::OK(); + } + + Status HandleDoGet(const std::string& peer, UcpCallDriver* driver) { + UcxServerCallContext context; + + ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); + SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); + Ticket ticket; + SERVER_RETURN_NOT_OK(driver, Ticket::Deserialize(frame->view()).Value(&ticket)); + + GetServerStream stream(peer, driver); + auto status = DoGet(context, std::move(ticket), &stream); + RETURN_NOT_OK(SendStatus(driver, status)); + return Status::OK(); + } + + Status HandleDoPut(const std::string& peer, UcpCallDriver* driver) { + UcxServerCallContext context; + + PutServerStream stream(peer, driver); + auto status = DoPut(context, &stream); + RETURN_NOT_OK(SendStatus(driver, status)); + // Must drain any unread messages, or the next call will get confused + internal::FlightData ignored; + while (stream.ReadData(&ignored)) { + } + return Status::OK(); + } + + Status HandleDoExchange(const std::string& peer, UcpCallDriver* driver) { + UcxServerCallContext context; + + ExchangeServerStream stream(peer, driver); + auto status = DoExchange(context, &stream); + RETURN_NOT_OK(SendStatus(driver, status)); + // Must drain any unread messages, or the next call will get confused + internal::FlightData ignored; + while (stream.ReadData(&ignored)) { + } + return Status::OK(); + } + + Status HandleOneCall(const std::string& peer, UcpCallDriver* driver, Frame* frame) { + SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kHeaders)); + ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); + ARROW_ASSIGN_OR_RAISE(auto method, headers.Get(":method:")); + if (method == kMethodGetFlightInfo) { + return HandleGetFlightInfo(peer, driver); + } else if (method == kMethodDoExchange) { + return HandleDoExchange(peer, driver); + } else if (method == kMethodDoGet) { + return HandleDoGet(peer, driver); + } else if (method == kMethodDoPut) { + return HandleDoPut(peer, driver); + } + RETURN_NOT_OK(SendStatus(driver, Status::NotImplemented(method))); + return Status::OK(); + } + + void WorkerLoop(ucp_conn_request_h request) { + std::string peer = "unknown:" + std::to_string(counter_++); + ucp_conn_request_attr_t request_attr; + std::memset(&request_attr, 0, sizeof(request_attr)); + request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR; + if (ucp_conn_request_query(request, &request_attr) == UCS_OK) { + ARROW_UNUSED(SockaddrToString(request_attr.client_address).Value(&peer)); + } + FLIGHT_LOG_PEER(DEBUG, peer) << "Received connection request"; + + auto maybe_worker = CreateWorker(); + if (!maybe_worker.ok()) { + FLIGHT_LOG_PEER(WARNING, peer) + << "Failed to create worker" << maybe_worker.status().ToString(); + auto status = ucp_listener_reject(listener_, request); + if (status != UCS_OK) { + FLIGHT_LOG_PEER(WARNING, peer) + << FromUcsStatus("ucp_listener_reject", status).ToString(); + } + return; + } + auto worker = maybe_worker.MoveValueUnsafe(); + + // Create an endpoint to the client, using the data worker + { + ucs_status_t status; + ucp_ep_params_t params; + std::memset(¶ms, 0, sizeof(params)); + params.field_mask = UCP_EP_PARAM_FIELD_CONN_REQUEST; + params.conn_request = request; + + ucp_ep_h client_endpoint; + + status = ucp_ep_create(worker->worker->get(), ¶ms, &client_endpoint); + if (status != UCS_OK) { + FLIGHT_LOG_PEER(WARNING, peer) + << "Failed to create endpoint: " + << FromUcsStatus("ucp_ep_create", status).ToString(); + return; + } + worker->driver.reset(new UcpCallDriver(worker->worker, client_endpoint)); + worker->driver->set_memory_manager(memory_manager_); + } + + while (listening_.load()) { + auto maybe_frame = worker->driver->ReadNextFrame(); + if (!maybe_frame.ok()) { + if (!maybe_frame.status().IsCancelled()) { + FLIGHT_LOG_PEER(WARNING, peer) + << "Failed to read next message: " << maybe_frame.status().ToString(); + } + break; + } + + auto status = HandleOneCall(peer, worker->driver.get(), maybe_frame->get()); + if (!status.ok()) { + FLIGHT_LOG_PEER(WARNING, peer) << "Call failed: " << status.ToString(); + break; + } + } + + // Clean up + auto status = worker->driver->Close(); + if (!status.ok()) { + FLIGHT_LOG_PEER(WARNING, peer) << "Failed to close worker: " << status.ToString(); + } + worker->worker.reset(); + FLIGHT_LOG_PEER(DEBUG, peer) << "Disconnected"; + } + + void DriveConnections() { + while (listening_.load()) { + { + // Check for connect requests in queue + std::unique_lock guard(pending_connections_mutex_); + while (!pending_connections_.empty()) { + ucp_conn_request_h request = pending_connections_.front(); + pending_connections_.pop(); + + auto submitted = rpc_pool_->Submit([this, request]() { WorkerLoop(request); }); + if (!submitted.ok()) { + ARROW_LOG(WARNING) << "Failed to submit task to handle client " + << submitted.status().ToString(); + } + } + } + + while (ucp_worker_progress(worker_conn_->get())) { + } + if (!listening_.load()) break; + auto status = ucp_worker_wait(worker_conn_->get()); + if (status != UCS_OK) { + FLIGHT_LOG(WARNING) << FromUcsStatus("ucp_worker_wait", status).ToString(); + } + } + } + + void EnqueueClient(ucp_conn_request_h connection_request) { + std::unique_lock guard(pending_connections_mutex_); + pending_connections_.push(connection_request); + guard.unlock(); + } + + arrow::Result> CreateWorker() { + auto worker = std::make_shared(); + + ucp_worker_params_t worker_params; + std::memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; + + ucp_worker_h ucp_worker; + auto status = ucp_worker_create(ucp_context_->get(), &worker_params, &ucp_worker); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status)); + worker->worker.reset(new UcpWorker(ucp_context_, ucp_worker)); + + // Set up Active Message (AM) handler + ucp_am_handler_param_t handler_params; + std::memset(&handler_params, 0, sizeof(handler_params)); + handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | + UCP_AM_HANDLER_PARAM_FIELD_CB | + UCP_AM_HANDLER_PARAM_FIELD_ARG; + handler_params.id = kUcpAmHandlerId; + handler_params.cb = HandleIncomingActiveMessage; + handler_params.arg = worker.get(); + + status = ucp_worker_set_am_recv_handler(worker->worker->get(), &handler_params); + RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status)); + return worker; + } + + /// Callback handler. A new client has connected to the server. + static void HandleIncomingConnection(ucp_conn_request_h connection_request, + void* data) { + UcxServerImpl* server = reinterpret_cast(data); + // TODO(lidavidm): enable shedding load above some threshold + // (which is a pitfall with gRPC/Java) + server->EnqueueClient(connection_request); + } + + static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header, + size_t header_length, void* data, + size_t data_length, + const ucp_am_recv_param_t* param) { + ClientWorker* worker = reinterpret_cast(self); + DCHECK(worker->driver); + return worker->driver->RecvActiveMessage(header, header_length, data, data_length, + param); + } + + std::shared_ptr ucp_context_; + // Listen for and handle incoming connections + std::shared_ptr worker_conn_; + ucp_listener_h listener_; + Location location_; + + // Counter for identifying peers when UCX doesn't give us a way + std::atomic counter_; + + std::shared_ptr rpc_pool_; + std::atomic listening_; + std::thread listener_thread_; + // std::thread::join cannot be called concurrently + std::mutex join_mutex_; + + std::mutex pending_connections_mutex_; + std::queue pending_connections_; +}; + +std::unique_ptr MakeUcxServerImpl( + FlightServerBase* base, std::shared_ptr memory_manager) { + return arrow::internal::make_unique(base, memory_manager); +} + +#undef SERVER_RETURN_NOT_OK +#undef FLIGHT_LOG +#undef FLIGHT_LOG_PEER + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc new file mode 100644 index 00000000000..47198ceceb0 --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/transport/ucx/util_internal.h" + +#include +#include +#include + +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/flight/types.h" +#include "arrow/util/base64.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, + struct sockaddr_storage* addr) { + std::string host = uri.host(); + if (host.empty()) { + return Status::Invalid("Must provide a host"); + } else if (uri.port() < 0) { + return Status::Invalid("Must provide a port"); + } + + std::memset(addr, 0, sizeof(*addr)); + + struct addrinfo* info = nullptr; + int err = getaddrinfo(host.c_str(), /*service=*/nullptr, /*hints=*/nullptr, &info); + if (err != 0) { + if (err == EAI_SYSTEM) { + return arrow::internal::IOErrorFromErrno(errno, "[getaddrinfo] Failure resolving ", + host); + } else { + return Status::IOError("[getaddrinfo] Failure resolving ", host, ": ", + gai_strerror(err)); + } + } + + if (!info) { + return Status::IOError("[getaddrinfo] Failure resolving ", host, + ": no results returned"); + } + + std::memcpy(addr, info->ai_addr, info->ai_addrlen); + const size_t addrlen = info->ai_addrlen; + if (info->ai_family == AF_INET) { + reinterpret_cast(addr)->sin_port = htons(uri.port()); + } else if (info->ai_family == AF_INET6) { + reinterpret_cast(addr)->sin6_port = htons(uri.port()); + } else { + freeaddrinfo(info); + return Status::Invalid("Unknown address family: ", info->ai_family); + } + return addrlen; +} + +Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status) { + switch (ucs_status) { + case UCS_OK: + return Status::OK(); + case UCS_INPROGRESS: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_INPROGRESS ", ucs_status_string(ucs_status)); + case UCS_ERR_NO_MESSAGE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_MESSAGE ", ucs_status_string(ucs_status)); + case UCS_ERR_NO_RESOURCE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_RESOURCE ", ucs_status_string(ucs_status)); + case UCS_ERR_IO_ERROR: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_IO_ERROR ", ucs_status_string(ucs_status)); + case UCS_ERR_NO_MEMORY: + return Status::OutOfMemory(context, ": UCX error ", + static_cast(ucs_status), ": ", + "UCS_ERR_NO_MEMORY ", ucs_status_string(ucs_status)); + case UCS_ERR_INVALID_PARAM: + return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_INVALID_PARAM ", + ucs_status_string(ucs_status)); + case UCS_ERR_UNREACHABLE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_UNREACHABLE ", ucs_status_string(ucs_status)); + case UCS_ERR_INVALID_ADDR: + return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_INVALID_ADDR ", + ucs_status_string(ucs_status)); + case UCS_ERR_NOT_IMPLEMENTED: + return Status::NotImplemented( + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_NOT_IMPLEMENTED ", ucs_status_string(ucs_status)); + case UCS_ERR_MESSAGE_TRUNCATED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_MESSAGE_TRUNCATED ", + ucs_status_string(ucs_status)); + case UCS_ERR_NO_PROGRESS: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_PROGRESS ", ucs_status_string(ucs_status)); + case UCS_ERR_BUFFER_TOO_SMALL: + return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_BUFFER_TOO_SMALL ", + ucs_status_string(ucs_status)); + case UCS_ERR_NO_ELEM: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_ELEM ", ucs_status_string(ucs_status)); + case UCS_ERR_SOME_CONNECTS_FAILED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_SOME_CONNECTS_FAILED ", + ucs_status_string(ucs_status)); + case UCS_ERR_NO_DEVICE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NO_DEVICE ", ucs_status_string(ucs_status)); + case UCS_ERR_BUSY: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_BUSY ", ucs_status_string(ucs_status)); + case UCS_ERR_CANCELED: + return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_CANCELED ", ucs_status_string(ucs_status)); + case UCS_ERR_SHMEM_SEGMENT: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_SHMEM_SEGMENT ", + ucs_status_string(ucs_status)); + case UCS_ERR_ALREADY_EXISTS: + return Status::AlreadyExists( + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_ALREADY_EXISTS ", ucs_status_string(ucs_status)); + case UCS_ERR_OUT_OF_RANGE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_OUT_OF_RANGE ", + ucs_status_string(ucs_status)); + case UCS_ERR_TIMED_OUT: + return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_TIMED_OUT ", ucs_status_string(ucs_status)); + case UCS_ERR_EXCEEDS_LIMIT: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_EXCEEDS_LIMIT ", + ucs_status_string(ucs_status)); + case UCS_ERR_UNSUPPORTED: + return Status::NotImplemented( + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_UNSUPPORTED ", ucs_status_string(ucs_status)); + case UCS_ERR_REJECTED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_REJECTED ", ucs_status_string(ucs_status)); + case UCS_ERR_NOT_CONNECTED: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_NOT_CONNECTED ", + ucs_status_string(ucs_status)); + case UCS_ERR_CONNECTION_RESET: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_CONNECTION_RESET ", + ucs_status_string(ucs_status)); + case UCS_ERR_FIRST_LINK_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_FIRST_LINK_FAILURE ", + ucs_status_string(ucs_status)); + case UCS_ERR_LAST_LINK_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_LAST_LINK_FAILURE ", + ucs_status_string(ucs_status)); + case UCS_ERR_FIRST_ENDPOINT_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_FIRST_ENDPOINT_FAILURE ", + ucs_status_string(ucs_status)); + case UCS_ERR_LAST_ENDPOINT_FAILURE: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_LAST_ENDPOINT_FAILURE ", + ucs_status_string(ucs_status)); + case UCS_ERR_ENDPOINT_TIMEOUT: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_ENDPOINT_TIMEOUT ", + ucs_status_string(ucs_status)); + case UCS_ERR_LAST: + return Status::IOError(context, ": UCX error ", static_cast(ucs_status), + ": ", "UCS_ERR_LAST ", ucs_status_string(ucs_status)); + default: + return Status::UnknownError( + context, ": Unknown UCX error: ", static_cast(ucs_status), " ", + ucs_status_string(ucs_status)); + } +} + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h new file mode 100644 index 00000000000..f05889a076c --- /dev/null +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/flight/visibility.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace ucx { + +static inline void UInt32ToBytesBe(const uint32_t in, uint8_t* out) { + out[0] = static_cast((in >> 24) & 0xFF); + out[1] = static_cast((in >> 16) & 0xFF); + out[2] = static_cast((in >> 8) & 0xFF); + out[3] = static_cast(in & 0xFF); +} + +static inline uint32_t BytesToUInt32Be(const uint8_t* in) { + return static_cast(in[3]) | (static_cast(in[2]) << 8) | + (static_cast(in[1]) << 16) | (static_cast(in[0]) << 24); +} + +ARROW_FLIGHT_EXPORT +Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status); + +/// \brief Helper to convert a Uri to a struct sockaddr (used in +/// ucp_listener_params_t) +/// +/// \return The length of the sockaddr +ARROW_FLIGHT_EXPORT +arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, + struct sockaddr_storage* addr); + +} // namespace ucx +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport_server.cc b/cpp/src/arrow/flight/transport_server.cc index fa5bf827100..4944a79b8fb 100644 --- a/cpp/src/arrow/flight/transport_server.cc +++ b/cpp/src/arrow/flight/transport_server.cc @@ -54,8 +54,7 @@ class TransportIpcMessageReader : public ipc::MessageReader { stream_finished_ = true; return nullptr; } - if (data->body && - ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) { + if (data->body) { ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_)); } *app_metadata_ = std::move(data->app_metadata); @@ -111,7 +110,7 @@ class TransportMessageReader final : public FlightMessageReader { arrow::Result Next() override { FlightStreamChunk out; - internal::FlightData* data; + internal::FlightData* data = nullptr; peekable_reader_->Peek(&data); if (!data) { out.app_metadata = nullptr; diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake index 7d7c83185ef..55bc2d01005 100644 --- a/cpp/src/arrow/util/config.h.cmake +++ b/cpp/src/arrow/util/config.h.cmake @@ -48,5 +48,6 @@ #cmakedefine ARROW_S3 #cmakedefine ARROW_USE_NATIVE_INT128 #cmakedefine ARROW_WITH_OPENTELEMETRY +#cmakedefine ARROW_WITH_UCX #cmakedefine GRPCPP_PP_INCLUDE diff --git a/docs/source/cpp/flight.rst b/docs/source/cpp/flight.rst index c1d2e43b9f4..75aea3c47c1 100644 --- a/docs/source/cpp/flight.rst +++ b/docs/source/cpp/flight.rst @@ -117,3 +117,38 @@ success/failure of the request. Any other return values are specified through out parameters. They also take an optional :class:`options ` parameter that allows specifying a timeout for the call. + +Alternative Transports +====================== + +The standard transport for Arrow Flight is gRPC_. The C++ +implementation also experimentally supports a transport based on +UCX_. To use it, use the protocol scheme ``ucx:`` when starting a +server or creating a client. + +UCX Transport +------------- + +Not all features of the gRPC transport are supported. See +:ref:`status-flight-rpc` for details. Also note these specific +caveats: + +- The server creates an independent UCP worker for each client. This + consumes more resources but provides better throughput. +- The client creates an independent UCP worker for each RPC + call. Again, this trades off resource consumption for + performance. This also means that unlike with gRPC, it is + essentially equivalent to make all calls with a single client or + with multiple clients. +- The UCX transport attempts to avoid copies where possible. In some + cases, it can directly reuse UCX-allocated buffers to back + :class:`arrow::Buffer` objects, however, this will also extend the + lifetime of associated UCX resources beyond the lifetime of the + Flight client or server object. +- Depending on the transport that UCX itself selects, you may find + that increasing ``UCX_MM_SEG_SIZE`` from the default (around 8KB) to + around 60KB improves performance (UCX will copy more data in a + single call). + +.. _gRPC: https://grpc.io/ +.. _UCX: https://openucx.org/ diff --git a/docs/source/status.rst b/docs/source/status.rst index 7c6157357a8..c30caed2f84 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -144,35 +144,81 @@ Notes: .. seealso:: The :ref:`format-ipc` specification. +.. _status-flight-rpc: Flight RPC ========== -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| | | | | | | | | -+=============================+=======+=======+=======+============+=======+=======+=======+ -| gRPC transport | ✓ | ✓ | ✓ | | ✓ (1) | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| gRPC + TLS transport | ✓ | ✓ | ✓ | | ✓ | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| RPC error codes | ✓ | ✓ | ✓ | | ✓ | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom client middleware | ✓ | ✓ | ✓ | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom server middleware | ✓ | ✓ | ✓ | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ +.. note:: Flight RPC is still experimental. + ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Flight RPC Transport | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| gRPC_ transport (grpc:, grpc+tcp:) | ✓ | ✓ | ✓ | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| gRPC domain socket transport (grpc+unix:) | ✓ | ✓ | ✓ | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| gRPC + TLS transport (grpc+tls:) | ✓ | ✓ | ✓ | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| UCX_ transport (ucx:) | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ + +Supported features in the gRPC transport: + ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| All RPC methods | ✓ | ✓ | ✓ | | × (1) | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call timeouts | ✓ | ✓ | ✓ | | | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call cancellation | ✓ | ✓ | ✓ | | | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Concurrent client calls (3) | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Custom middleware | ✓ | ✓ | ✓ | | | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| RPC error codes | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ + +Supported features in the UCX transport: + ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| All RPC methods | × (4) | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Authentication handlers | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call timeouts | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Call cancellation | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Concurrent client calls | ✓ (5) | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Custom middleware | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| RPC error codes | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ Notes: * \(1) No support for handshake or DoExchange. * \(2) Support using AspNetCore authentication handlers. +* \(3) Whether a single client can support multiple concurrent calls. +* \(4) Only support for DoExchange, DoGet, DoPut, and GetFlightInfo. +* \(5) Each concurrent call is a separate connection to the server + (unlike gRPC where concurrent calls are multiplexed over a single + connection). This will generally provide better throughput but + consumes more resources both on the server and the client. .. seealso:: The :ref:`flight-rpc` specification. +.. _gRPC: https://grpc.io/ +.. _UCX: https://openucx.org/ C Data Interface ================ From 2ae0f294441b27038714b80a9f11f17680fd2039 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 21 Mar 2022 12:47:24 -0400 Subject: [PATCH 02/16] ARROW-15706: [C++][FlightRPC] Fix hanging/race condition --- .../flight/transport/ucx/ucx_internal.cc | 13 +++- .../arrow/flight/transport/ucx/ucx_internal.h | 2 + .../arrow/flight/transport/ucx/ucx_server.cc | 77 +++++++------------ .../flight/transport/ucx/util_internal.cc | 31 ++++++++ .../flight/transport/ucx/util_internal.h | 3 + 5 files changed, 72 insertions(+), 54 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index 5baabc5b805..09a237590ab 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -487,6 +487,7 @@ class UcpCallDriver::Impl { read_memory_pool_(default_memory_pool()), write_memory_pool_(default_memory_pool()), memory_manager_(CPUDevice::Instance()->default_memory_manager()), + name_("(unknown remote)"), counter_(0) { #if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP) TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(), @@ -495,11 +496,13 @@ class UcpCallDriver::Impl { ucp_ep_attr_t attrs; std::memset(&attrs, 0, sizeof(attrs)); - attrs.field_mask = UCP_EP_ATTR_FIELD_NAME; + attrs.field_mask = + UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR | UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR; if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) { - name_ = attrs.name; - } else { - name_ = "(unknown remote)"; + std::string local_addr, remote_addr; + ARROW_UNUSED(SockaddrToString(attrs.local_sockaddr).Value(&local_addr)); + ARROW_UNUSED(SockaddrToString(attrs.remote_sockaddr).Value(&remote_addr)); + name_ = "local:" + local_addr + ";remote:" + remote_addr; } } @@ -854,6 +857,7 @@ class UcpCallDriver::Impl { void set_write_memory_pool(MemoryPool* pool) { write_memory_pool_ = pool ? pool : default_memory_pool(); } + const std::string& peer() const { return name_; } private: class PendingAmSend { @@ -1141,6 +1145,7 @@ void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) { void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) { impl_->set_write_memory_pool(pool); } +const std::string& UcpCallDriver::peer() const { return impl_->peer(); } } // namespace ucx } // namespace transport diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h index 389ab2ea2a8..bd176e23699 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h @@ -327,6 +327,8 @@ class UcpCallDriver { void set_read_memory_pool(MemoryPool* memory_pool); /// \brief Set memory pool for scratch space used during writing. void set_write_memory_pool(MemoryPool* memory_pool); + /// \brief Get a debug string naming the peer. + const std::string& peer() const; /// \brief Process an incoming active message. This will unblock the /// corresponding call to ReadFrameAsync/ReadNextFrame. diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index b9612a19d0a..464ccc6a5c9 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -83,8 +83,8 @@ class UcxServerStream : public internal::ServerDataStream { // auto-adjusted, or at least configurable) constexpr static size_t kBackpressureThreshold = 8; - UcxServerStream(std::string peer, UcpCallDriver* driver) - : peer_(std::move(peer)), driver_(driver), writes_done_(false) {} + explicit UcxServerStream(UcpCallDriver* driver) + : peer_(driver->peer()), driver_(driver), writes_done_(false) {} Status WritesDone() override { RETURN_NOT_OK(CheckBackpressure(0)); @@ -131,8 +131,8 @@ class GetServerStream : public UcxServerStream { class PutServerStream : public UcxServerStream { public: - PutServerStream(std::string peer, UcpCallDriver* driver) - : UcxServerStream(std::move(peer), driver), finished_(false) {} + explicit PutServerStream(UcpCallDriver* driver) + : UcxServerStream(driver), finished_(false) {} bool ReadData(internal::FlightData* data) override { if (finished_) return false; @@ -202,32 +202,6 @@ class ExchangeServerStream : public PutServerStream { return Status::NotImplemented("Not supported on this stream"); } }; - -arrow::Result SockaddrToString(const struct sockaddr_storage& address) { - std::string result; - if (address.ss_family != AF_INET && address.ss_family != AF_INET6) { - return Status::NotImplemented("Unknown address family"); - } - - uint16_t port = 0; - if (address.ss_family == AF_INET) { - result.resize(INET_ADDRSTRLEN + 1); - port = ntohs(reinterpret_cast(&address)->sin_port); - result[INET_ADDRSTRLEN] = ':'; - result += std::to_string(port); - } else { - result.resize(INET6_ADDRSTRLEN + 1); - port = ntohs(reinterpret_cast(&address)->sin6_port); - result[INET_ADDRSTRLEN] = ':'; - result += std::to_string(port); - } - if (!inet_ntop(address.ss_family, &address, &result[0], result.size())) { - return arrow::internal::IOErrorFromErrno(errno, - "Could not convert address to string"); - } - - return result; -} } // namespace class ARROW_FLIGHT_EXPORT UcxServerImpl @@ -398,7 +372,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl return Status::OK(); } - Status HandleGetFlightInfo(const std::string& peer, UcpCallDriver* driver) { + Status HandleGetFlightInfo(UcpCallDriver* driver) { UcxServerCallContext context; ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); @@ -419,7 +393,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl return Status::OK(); } - Status HandleDoGet(const std::string& peer, UcpCallDriver* driver) { + Status HandleDoGet(UcpCallDriver* driver) { UcxServerCallContext context; ARROW_ASSIGN_OR_RAISE(auto frame, driver->ReadNextFrame()); @@ -427,16 +401,16 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl Ticket ticket; SERVER_RETURN_NOT_OK(driver, Ticket::Deserialize(frame->view()).Value(&ticket)); - GetServerStream stream(peer, driver); + GetServerStream stream(driver); auto status = DoGet(context, std::move(ticket), &stream); RETURN_NOT_OK(SendStatus(driver, status)); return Status::OK(); } - Status HandleDoPut(const std::string& peer, UcpCallDriver* driver) { + Status HandleDoPut(UcpCallDriver* driver) { UcxServerCallContext context; - PutServerStream stream(peer, driver); + PutServerStream stream(driver); auto status = DoPut(context, &stream); RETURN_NOT_OK(SendStatus(driver, status)); // Must drain any unread messages, or the next call will get confused @@ -446,10 +420,10 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl return Status::OK(); } - Status HandleDoExchange(const std::string& peer, UcpCallDriver* driver) { + Status HandleDoExchange(UcpCallDriver* driver) { UcxServerCallContext context; - ExchangeServerStream stream(peer, driver); + ExchangeServerStream stream(driver); auto status = DoExchange(context, &stream); RETURN_NOT_OK(SendStatus(driver, status)); // Must drain any unread messages, or the next call will get confused @@ -459,18 +433,18 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl return Status::OK(); } - Status HandleOneCall(const std::string& peer, UcpCallDriver* driver, Frame* frame) { + Status HandleOneCall(UcpCallDriver* driver, Frame* frame) { SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kHeaders)); ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer))); ARROW_ASSIGN_OR_RAISE(auto method, headers.Get(":method:")); if (method == kMethodGetFlightInfo) { - return HandleGetFlightInfo(peer, driver); + return HandleGetFlightInfo(driver); } else if (method == kMethodDoExchange) { - return HandleDoExchange(peer, driver); + return HandleDoExchange(driver); } else if (method == kMethodDoGet) { - return HandleDoGet(peer, driver); + return HandleDoGet(driver); } else if (method == kMethodDoPut) { - return HandleDoPut(peer, driver); + return HandleDoPut(driver); } RETURN_NOT_OK(SendStatus(driver, Status::NotImplemented(method))); return Status::OK(); @@ -478,11 +452,13 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl void WorkerLoop(ucp_conn_request_h request) { std::string peer = "unknown:" + std::to_string(counter_++); - ucp_conn_request_attr_t request_attr; - std::memset(&request_attr, 0, sizeof(request_attr)); - request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR; - if (ucp_conn_request_query(request, &request_attr) == UCS_OK) { - ARROW_UNUSED(SockaddrToString(request_attr.client_address).Value(&peer)); + { + ucp_conn_request_attr_t request_attr; + std::memset(&request_attr, 0, sizeof(request_attr)); + request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR; + if (ucp_conn_request_query(request, &request_attr) == UCS_OK) { + ARROW_UNUSED(SockaddrToString(request_attr.client_address).Value(&peer)); + } } FLIGHT_LOG_PEER(DEBUG, peer) << "Received connection request"; @@ -518,6 +494,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl } worker->driver.reset(new UcpCallDriver(worker->worker, client_endpoint)); worker->driver->set_memory_manager(memory_manager_); + peer = worker->driver->peer(); } while (listening_.load()) { @@ -530,7 +507,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl break; } - auto status = HandleOneCall(peer, worker->driver.get(), maybe_frame->get()); + auto status = HandleOneCall(worker->driver.get(), maybe_frame->get()); if (!status.ok()) { FLIGHT_LOG_PEER(WARNING, peer) << "Call failed: " << status.ToString(); break; @@ -548,6 +525,8 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl void DriveConnections() { while (listening_.load()) { + while (ucp_worker_progress(worker_conn_->get())) { + } { // Check for connect requests in queue std::unique_lock guard(pending_connections_mutex_); @@ -563,8 +542,6 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl } } - while (ucp_worker_progress(worker_conn_->get())) { - } if (!listening_.load()) break; auto status = ucp_worker_wait(worker_conn_->get()); if (status != UCS_OK) { diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc index 47198ceceb0..1966f39bde1 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -80,6 +80,37 @@ arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, return addrlen; } +arrow::Result SockaddrToString(const struct sockaddr_storage& address) { + std::string result = ""; + if (address.ss_family != AF_INET && address.ss_family != AF_INET6) { + return Status::NotImplemented("Unknown address family"); + } + + uint16_t port = 0; + if (address.ss_family == AF_INET) { + result.resize(INET_ADDRSTRLEN + 1); + if (!inet_ntop(address.ss_family, &address, &result[0], INET_ADDRSTRLEN)) { + return arrow::internal::IOErrorFromErrno(errno, + "Could not convert address to string"); + } + port = ntohs(reinterpret_cast(&address)->sin_port); + } else { + result.resize(INET6_ADDRSTRLEN + 1); + if (!inet_ntop(address.ss_family, &address, &result[0], INET6_ADDRSTRLEN)) { + return arrow::internal::IOErrorFromErrno(errno, + "Could not convert address to string"); + } + port = ntohs(reinterpret_cast(&address)->sin6_port); + } + + const size_t pos = result.find('\0'); + DCHECK_NE(pos, std::string::npos); + result[pos] = ':'; + result.resize(pos + 1); + result += std::to_string(port); + return result; +} + Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status) { switch (ucs_status) { case UCS_OK: diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h index f05889a076c..f88b1076e89 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.h +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h @@ -52,6 +52,9 @@ ARROW_FLIGHT_EXPORT arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, struct sockaddr_storage* addr); +ARROW_FLIGHT_EXPORT +arrow::Result SockaddrToString(const struct sockaddr_storage& address); + } // namespace ucx } // namespace transport } // namespace flight From 3b78fd12f0970de1430ca9eb6718192403a97556 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Mar 2022 10:20:08 -0400 Subject: [PATCH 03/16] ARROW-15706: [C++][FlightRPC] Tweak how we use pkg-config --- cpp/src/arrow/flight/transport/ucx/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt index 2c0e71de59f..6e315b68d6c 100644 --- a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt +++ b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt @@ -19,7 +19,7 @@ add_custom_target(arrow_flight_transport_ucx) arrow_install_all_headers("arrow/flight/transport/ucx") find_package(PkgConfig REQUIRED) -pkg_check_modules(UCX REQUIRED ucx) +pkg_check_modules(UCX REQUIRED IMPORTED_TARGET ucx) set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS ucx_client.cc @@ -30,7 +30,7 @@ set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS) include_directories(SYSTEM ${UCX_INCLUDE_DIRS}) -list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS ${UCX_LIBRARIES}) +list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS PkgConfig::UCX) add_arrow_lib(arrow_flight_transport_ucx # CMAKE_PACKAGE_NAME From 0183978838df26f9413500e2ef961414c3497b9b Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Mar 2022 10:58:12 -0400 Subject: [PATCH 04/16] ARROW-15706: [C++][FlightRPC] Fix SockaddrToString --- cpp/src/arrow/flight/transport/ucx/util_internal.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc index 1966f39bde1..d04e797633c 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -89,18 +89,21 @@ arrow::Result SockaddrToString(const struct sockaddr_storage& addre uint16_t port = 0; if (address.ss_family == AF_INET) { result.resize(INET_ADDRSTRLEN + 1); - if (!inet_ntop(address.ss_family, &address, &result[0], INET_ADDRSTRLEN)) { + const auto* in_addr = reinterpret_cast(&address); + if (!inet_ntop(address.ss_family, &in_addr->sin_addr, &result[0], INET_ADDRSTRLEN)) { return arrow::internal::IOErrorFromErrno(errno, "Could not convert address to string"); } - port = ntohs(reinterpret_cast(&address)->sin_port); + port = ntohs(in_addr->sin_port); } else { result.resize(INET6_ADDRSTRLEN + 1); - if (!inet_ntop(address.ss_family, &address, &result[0], INET6_ADDRSTRLEN)) { + const auto* in6_addr = reinterpret_cast(&address); + if (!inet_ntop(address.ss_family, &in6_addr->sin6_addr, &result[0], + INET6_ADDRSTRLEN)) { return arrow::internal::IOErrorFromErrno(errno, "Could not convert address to string"); } - port = ntohs(reinterpret_cast(&address)->sin6_port); + port = ntohs(in6_addr->sin6_port); } const size_t pos = result.find('\0'); From c97addaf3e0547c109f2611b9bd84bba82451373 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Mar 2022 10:58:22 -0400 Subject: [PATCH 05/16] ARROW-15706: [C++][FlightRPC] Add more debug logging --- cpp/src/arrow/flight/transport/ucx/ucx_client.cc | 3 +++ cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 1 + 2 files changed, 4 insertions(+) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index 04122942ba5..d6b38973c43 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -110,6 +110,9 @@ class ClientConnection { // Create endpoint for remote worker struct sockaddr_storage connect_addr; ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr)); + std::string peer; + ARROW_UNUSED(SockaddrToString(connect_addr).Value(&peer)); + ARROW_LOG(DEBUG) << "Connecting to " << peer; ucp_ep_params_t params; params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME | diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index 464ccc6a5c9..8a053c89ff2 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -296,6 +296,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl raw_uri += ":"; raw_uri += std::to_string( ntohs(reinterpret_cast(&attr.sockaddr)->sin_port)); + FLIGHT_LOG(DEBUG) << "Listening on " << raw_uri; RETURN_NOT_OK(Location::Parse(raw_uri, &location_)); } From f9a41e8f08337b1e4bda7c2795b9de0fb9e6bd13 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Mar 2022 11:05:21 -0400 Subject: [PATCH 06/16] ARROW-15706: [C++][FlightRPC] Try harder to gt a valid family in UriToSockaddr --- .../flight/transport/ucx/util_internal.cc | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc index d04e797633c..d1813988d34 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -62,22 +62,27 @@ arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, } } - if (!info) { - return Status::IOError("[getaddrinfo] Failure resolving ", host, - ": no results returned"); - } + struct addrinfo* cur_info = info; + while (cur_info) { + if (cur_info->ai_family != AF_INET && cur_info->ai_family != AF_INET6) { + cur_info = cur_info->ai_next; + continue; + } - std::memcpy(addr, info->ai_addr, info->ai_addrlen); - const size_t addrlen = info->ai_addrlen; - if (info->ai_family == AF_INET) { - reinterpret_cast(addr)->sin_port = htons(uri.port()); - } else if (info->ai_family == AF_INET6) { - reinterpret_cast(addr)->sin6_port = htons(uri.port()); - } else { + std::memcpy(addr, info->ai_addr, info->ai_addrlen); + if (cur_info->ai_family == AF_INET) { + reinterpret_cast(addr)->sin_port = htons(uri.port()); + } else if (cur_info->ai_family == AF_INET6) { + reinterpret_cast(addr)->sin6_port = htons(uri.port()); + } + size_t addrlen = info->ai_addrlen; freeaddrinfo(info); - return Status::Invalid("Unknown address family: ", info->ai_family); + return addrlen; } - return addrlen; + + if (info) freeaddrinfo(info); + return Status::IOError("[getaddrinfo] Failure resolving ", host, + ": no results of a supported family returned"); } arrow::Result SockaddrToString(const struct sockaddr_storage& address) { From 9c6b8470e1702b57c1a89dd34d65736d3b89fa9d Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Mar 2022 08:53:33 -0400 Subject: [PATCH 07/16] ARROW-15706: [C++][FlightRPC] Don't try to send nullptr IOV arrays --- .../flight/transport/ucx/ucx_internal.cc | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index 09a237590ab..3ed33513d01 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -538,8 +538,6 @@ class UcpCallDriver::Impl { } Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) { - static uint8_t kZeroes[1] = {0}; - RETURN_NOT_OK(CheckClosed()); void* request = nullptr; @@ -553,7 +551,7 @@ class UcpCallDriver::Impl { if (size == 0) { // UCX appears to crash on zero-byte payloads request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), - kZeroes, + padding_bytes_.data(), /*size=*/1, &request_param); } else { request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(), @@ -691,13 +689,17 @@ class UcpCallDriver::Impl { pending_send = arrow::internal::make_unique(); auto* pending_contig = reinterpret_cast(pending_send.get()); - ARROW_ASSIGN_OR_RAISE( - pending_contig->ipc_message, - AllocateBuffer(payload.ipc_message.body_length, write_memory_pool_)); + const int64_t body_length = std::max(payload.ipc_message.body_length, 1); + ARROW_ASSIGN_OR_RAISE(pending_contig->ipc_message, + AllocateBuffer(body_length, write_memory_pool_)); TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message, &pending_contig->memh_p); uint8_t* ipc_message = pending_contig->ipc_message->mutable_data(); + if (payload.ipc_message.body_length == 0) { + std::memset(ipc_message, '\0', 1); + } + for (const auto& buffer : payload.ipc_message.body_buffers) { if (!buffer || buffer->size() == 0) continue; @@ -750,10 +752,21 @@ class UcpCallDriver::Impl { } } + if (total_buffers == 0) { + // UCX cannot handle zero-byte payloads + pending_iov->iovs.resize(1); + pending_iov->iovs[0].buffer = + const_cast(reinterpret_cast(padding_bytes_.data())); + pending_iov->iovs[0].length = 1; + } + send_data = pending_iov->iovs.data(); send_size = pending_iov->iovs.size(); } + DCHECK(send_data) << "Payload cannot be nullptr"; + DCHECK_GT(send_size, 0) << "Payload cannot be empty"; + RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++, payload.ipc_message.body_length)); pending_send->driver = this; From 3483ee944420fef494d330acf5f3f05d98024d1e Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 29 Mar 2022 11:19:15 -0400 Subject: [PATCH 08/16] ARROW-15706: [C++][FlightRPC] Be more resilient to localhost resolving to an IPv6 address --- cpp/src/arrow/flight/test_definitions.cc | 26 +++++++++---------- cpp/src/arrow/flight/test_util.h | 2 +- .../ucx/flight_transport_ucx_test.cc | 2 +- .../arrow/flight/transport/ucx/ucx_client.cc | 13 ++++++++++ .../arrow/flight/transport/ucx/ucx_server.cc | 17 +++++++++--- 5 files changed, 41 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index 2cfac641446..1ec06a1f004 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -45,7 +45,7 @@ using arrow::internal::checked_cast; void ConnectivityTest::TestGetPort() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); ASSERT_GT(server->port(), 0); @@ -53,7 +53,7 @@ void ConnectivityTest::TestGetPort() { void ConnectivityTest::TestBuilderHook() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); bool builder_hook_run = false; options.builder_hook = [&builder_hook_run](void* builder) { @@ -68,7 +68,7 @@ void ConnectivityTest::TestBuilderHook() { void ConnectivityTest::TestShutdown() { // Regression test for ARROW-15181 constexpr int kIterations = 10; - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); for (int i = 0; i < kIterations; i++) { std::unique_ptr server = ExampleTestServer(); @@ -84,7 +84,7 @@ void ConnectivityTest::TestShutdown() { void ConnectivityTest::TestShutdownWithDeadline() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); ASSERT_GT(server->port(), 0); @@ -96,13 +96,13 @@ void ConnectivityTest::TestShutdownWithDeadline() { } void ConnectivityTest::TestBrokenConnection() { std::unique_ptr server = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); std::unique_ptr client; ASSERT_OK_AND_ASSIGN(location, - Location::ForScheme(transport(), "localhost", server->port())); + Location::ForScheme(transport(), "127.0.0.1", server->port())); ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location)); ASSERT_OK(server->Shutdown()); @@ -117,7 +117,7 @@ void ConnectivityTest::TestBrokenConnection() { void DataTest::SetUp() { server_ = ExampleTestServer(); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server_->Init(options)); @@ -129,7 +129,7 @@ void DataTest::TearDown() { } Status DataTest::ConnectClient() { ARROW_ASSIGN_OR_RAISE(auto location, - Location::ForScheme(transport(), "localhost", server_->port())); + Location::ForScheme(transport(), "127.0.0.1", server_->port())); ARROW_ASSIGN_OR_RAISE(client_, FlightClient::Connect(location)); return Status::OK(); } @@ -638,7 +638,7 @@ class DoPutTestServer : public FlightServerBase { }; void DoPutTest::SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, @@ -766,7 +766,7 @@ void DoPutTest::TestLargeBatch() { void DoPutTest::TestSizeLimit() { const int64_t size_limit = 4096; ASSERT_OK_AND_ASSIGN(auto location, - Location::ForScheme(transport(), "localhost", server_->port())); + Location::ForScheme(transport(), "127.0.0.1", server_->port())); auto client_options = FlightClientOptions::Defaults(); client_options.write_size_limit_bytes = size_limit; ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location, client_options)); @@ -866,7 +866,7 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context, } void AppMetadataTest::SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, @@ -1045,7 +1045,7 @@ class IpcOptionsTestServer : public FlightServerBase { }; void IpcOptionsTest::SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, @@ -1241,7 +1241,7 @@ void CudaDataTest::SetUp() { impl_->device = std::move(device); impl_->context = std::move(context); - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [this](FlightServerOptions* options) { diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index 5320b958d58..d5b774b4a37 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -113,7 +113,7 @@ Status MakeServer(const Location& location, std::unique_ptr* s RETURN_NOT_OK(make_server_options(&server_options)); RETURN_NOT_OK((*server)->Init(server_options)); std::string uri = - location.scheme() + "://localhost:" + std::to_string((*server)->port()); + location.scheme() + "://127.0.0.1:" + std::to_string((*server)->port()); ARROW_ASSIGN_OR_RAISE(auto real_location, Location::Parse(uri)); FlightClientOptions client_options = FlightClientOptions::Defaults(); RETURN_NOT_OK(make_client_options(&client_options)); diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc index bb9e15fe1bf..f159b11929f 100644 --- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc +++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc @@ -276,7 +276,7 @@ class SimpleTestServer : public FlightServerBase { class TestUcx : public ::testing::Test { public: void SetUp() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "localhost", 0)); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "127.0.0.1", 0)); ASSERT_OK(MakeServer( location, &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index d6b38973c43..ce42e5237a3 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -127,6 +127,7 @@ class ClientConnection { } driver_.reset(new UcpCallDriver(ucp_worker_, remote_endpoint_)); + ARROW_LOG(DEBUG) << "Connected to " << driver_->peer(); { // Set up Active Message (AM) handler @@ -528,6 +529,18 @@ class ARROW_FLIGHT_EXPORT UcxClientImpl status = ucp_config_read(nullptr, nullptr, &ucp_config); RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); + // If location is IPv6, must adjust UCX config + // XXX: we assume locations always resolve to IPv6 or IPv4 but + // that is not necessarily true. + { + struct sockaddr_storage connect_addr; + RETURN_NOT_OK(UriToSockaddr(uri, &connect_addr)); + if (connect_addr.ss_family == AF_INET6) { + status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); + RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status)); + } + } + std::memset(&ucp_params, 0, sizeof(ucp_params)); ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP; diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index 8a053c89ff2..117cdc42049 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -219,8 +219,12 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl } Status Init(const FlightServerOptions& options, const arrow::internal::Uri& uri) { + // TODO: this pool should be resized to match CPU cores ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(8)); + struct sockaddr_storage listen_addr; + ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); + // Init UCX { ucp_config_t* ucp_config; @@ -230,6 +234,12 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl status = ucp_config_read(nullptr, nullptr, &ucp_config); RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status)); + // If location is IPv6, must adjust UCX config + if (listen_addr.ss_family == AF_INET6) { + status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); + RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status)); + } + // Allow application to override UCP config if (options.builder_hook) options.builder_hook(ucp_config); @@ -265,9 +275,6 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl ucp_listener_params_t params; ucs_status_t status; - struct sockaddr_storage listen_addr; - ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); - params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; params.sockaddr.addr = reinterpret_cast(&listen_addr); @@ -296,7 +303,9 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl raw_uri += ":"; raw_uri += std::to_string( ntohs(reinterpret_cast(&attr.sockaddr)->sin_port)); - FLIGHT_LOG(DEBUG) << "Listening on " << raw_uri; + std::string listen_str; + ARROW_UNUSED(SockaddrToString(attr.sockaddr).Value(&listen_str)); + FLIGHT_LOG(DEBUG) << "Listening on " << listen_str; RETURN_NOT_OK(Location::Parse(raw_uri, &location_)); } From 2d64cce78864b05d959839253eaaaa9ebed481ea Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 30 Mar 2022 09:48:36 -0400 Subject: [PATCH 09/16] ARROW-15706: [C++][FlightRPC] Address review feedback --- cpp/src/arrow/flight/flight_benchmark.cc | 6 +++--- cpp/src/arrow/flight/perf_server.cc | 14 ++++++------- .../ucx/flight_transport_ucx_test.cc | 14 ++++++------- cpp/src/arrow/flight/transport/ucx/ucx.cc | 2 ++ .../arrow/flight/transport/ucx/ucx_client.cc | 16 ++++++-------- .../arrow/flight/transport/ucx/ucx_server.cc | 21 +++++++++---------- .../flight/transport/ucx/util_internal.cc | 2 +- .../flight/transport/ucx/util_internal.h | 10 ++++----- 8 files changed, 39 insertions(+), 46 deletions(-) diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc index b5594e2d056..fa0cc9a3d53 100644 --- a/cpp/src/arrow/flight/flight_benchmark.cc +++ b/cpp/src/arrow/flight/flight_benchmark.cc @@ -513,9 +513,9 @@ int main(int argc, char** argv) { << std::endl; return EXIT_FAILURE; } - ARROW_CHECK_OK(arrow::flight::Location::Parse( - "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_server_port), - &location)); + ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + + std::to_string(FLAGS_server_port)) + .Value(&location)); #else std::cerr << "Not built with transport: " << FLAGS_transport << std::endl; return EXIT_FAILURE; diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index 6bb30f62ddd..37e3ec4d771 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -105,7 +105,7 @@ class PerfDataStream : public FlightDataStream { if (records_sent_ >= total_records_) { // Signal that iteration is over payload.ipc_message.metadata = nullptr; - return Status::OK(); + return payload; } if (verify_) { @@ -290,12 +290,12 @@ int main(int argc, char** argv) { std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl; return EXIT_FAILURE; } - ARROW_CHECK_OK(arrow::flight::Location::Parse( - "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_port), - &bind_location)); - ARROW_CHECK_OK(arrow::flight::Location::Parse( - "ucx://" + FLAGS_server_host + ":" + std::to_string(FLAGS_port), - &connect_location)); + ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + + std::to_string(FLAGS_port)) + .Value(&bind_location)); + ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" + + std::to_string(FLAGS_port)) + .Value(&connect_location)); } else { std::cerr << "Transport does not support domain sockets: " << FLAGS_transport << std::endl; diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc index f159b11929f..2795a2996f5 100644 --- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc +++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc @@ -254,7 +254,7 @@ class SimpleTestServer : public FlightServerBase { return status_; } auto examples = ExampleFlightInfo(); - *info = std::unique_ptr(new FlightInfo(examples[0])); + info->reset(new FlightInfo(examples[0])); return Status::OK(); } @@ -309,13 +309,12 @@ TEST_F(TestUcx, SequentialClients) { Ticket ticket{"a"}; std::unique_ptr stream1, stream2; - std::shared_ptr
table1, table2; ASSERT_OK(client_->DoGet(ticket, &stream1)); - ASSERT_OK(stream1->ReadAll(&table1)); + ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); - ASSERT_OK(client_->DoGet(ticket, &stream2)); - ASSERT_OK(stream2->ReadAll(&table2)); + ASSERT_OK(client2->DoGet(ticket, &stream2)); + ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); AssertTablesEqual(*table1, *table2); } @@ -328,13 +327,12 @@ TEST_F(TestUcx, ConcurrentClients) { Ticket ticket{"a"}; std::unique_ptr stream1, stream2; - std::shared_ptr
table1, table2; ASSERT_OK(client_->DoGet(ticket, &stream1)); ASSERT_OK(client2->DoGet(ticket, &stream2)); - ASSERT_OK(stream1->ReadAll(&table1)); - ASSERT_OK(stream2->ReadAll(&table2)); + ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); + ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); AssertTablesEqual(*table1, *table2); } diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc index 0b61dbb93e9..0e3daf60213 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx.cc @@ -29,7 +29,9 @@ namespace flight { namespace transport { namespace ucx { +namespace { std::once_flag kInitializeOnce; +} void InitializeFlightUcx() { std::call_once(kInitializeOnce, []() { auto* registry = flight::internal::GetDefaultTransportRegistry(); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index ce42e5237a3..8b01442a687 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -54,9 +54,8 @@ namespace flight { namespace transport { namespace ucx { -class UcxClientImpl; - namespace { +class UcxClientImpl; Status MergeStatuses(Status server_status, Status transport_status) { if (server_status.ok()) { @@ -293,10 +292,8 @@ class WriteClientStream : public UcxClientStream { while (true) { { std::unique_lock guard(driver_mutex_); - working_cv_.wait(guard, [this] { - return finished_ || incoming_.is_valid() || outgoing_.is_valid(); - }); - if (finished_) return; + working_cv_.wait(guard, + [this] { return incoming_.is_valid() || outgoing_.is_valid(); }); } while (true) { @@ -318,8 +315,8 @@ class WriteClientStream : public UcxClientStream { break; } driver_->MakeProgress(); - if (finished_) return; } + if (finished_) return; } } @@ -502,10 +499,8 @@ class ExchangeClientStream : public WriteClientStream { internal::FlightData next_data_; ReadState read_state_; }; -} // namespace -class ARROW_FLIGHT_EXPORT UcxClientImpl - : public arrow::flight::internal::ClientTransport { +class UcxClientImpl : public arrow::flight::internal::ClientTransport { public: UcxClientImpl() {} @@ -719,6 +714,7 @@ Status UcxClientStream::DoFinish() { } return MergeStatuses(server_status_, io_status_); } +} // namespace std::unique_ptr MakeUcxClientImpl() { return arrow::internal::make_unique(); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index 117cdc42049..a4cc99dc675 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -153,8 +153,7 @@ class PutServerStream : public UcxServerStream { Status WritePutMetadata(const Buffer& payload) override { if (finished_) return Status::OK(); // Send synchronously (we don't control payload lifetime) - RETURN_NOT_OK(driver_->SendFrame(FrameType::kBuffer, payload.data(), payload.size())); - return Status::OK(); + return driver_->SendFrame(FrameType::kBuffer, payload.data(), payload.size()); } private: @@ -202,10 +201,8 @@ class ExchangeServerStream : public PutServerStream { return Status::NotImplemented("Not supported on this stream"); } }; -} // namespace -class ARROW_FLIGHT_EXPORT UcxServerImpl - : public arrow::flight::internal::ServerTransport { +class UcxServerImpl : public arrow::flight::internal::ServerTransport { public: using arrow::flight::internal::ServerTransport::ServerTransport; @@ -218,9 +215,10 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl } } - Status Init(const FlightServerOptions& options, const arrow::internal::Uri& uri) { - // TODO: this pool should be resized to match CPU cores - ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(8)); + Status Init(const FlightServerOptions& options, + const arrow::internal::Uri& uri) override { + const auto num_threads = std::max(8, std::thread::hardware_concurrency()); + ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(num_threads)); struct sockaddr_storage listen_addr; ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); @@ -306,7 +304,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl std::string listen_str; ARROW_UNUSED(SockaddrToString(attr.sockaddr).Value(&listen_str)); FLIGHT_LOG(DEBUG) << "Listening on " << listen_str; - RETURN_NOT_OK(Location::Parse(raw_uri, &location_)); + ARROW_ASSIGN_OR_RAISE(location_, Location::Parse(raw_uri)); } { @@ -354,7 +352,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl } Status Wait() override { - std::unique_lock guard(join_mutex_); + std::lock_guard guard(join_mutex_); try { listener_thread_.join(); } catch (const std::system_error& e) { @@ -594,7 +592,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl return worker; } - /// Callback handler. A new client has connected to the server. + // Callback handler. A new client has connected to the server. static void HandleIncomingConnection(ucp_conn_request_h connection_request, void* data) { UcxServerImpl* server = reinterpret_cast(data); @@ -631,6 +629,7 @@ class ARROW_FLIGHT_EXPORT UcxServerImpl std::mutex pending_connections_mutex_; std::queue pending_connections_; }; +} // namespace std::unique_ptr MakeUcxServerImpl( FlightServerBase* base, std::shared_ptr memory_manager) { diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc index d1813988d34..5414ebb6013 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -69,7 +69,7 @@ arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, continue; } - std::memcpy(addr, info->ai_addr, info->ai_addrlen); + std::memcpy(addr, cur_info->ai_addr, cur_info->ai_addrlen); if (cur_info->ai_family == AF_INET) { reinterpret_cast(addr)->sin_port = htons(uri.port()); } else if (cur_info->ai_family == AF_INET6) { diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h index f88b1076e89..3985ec39e16 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.h +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h @@ -22,6 +22,8 @@ #include #include "arrow/flight/visibility.h" +#include "arrow/util/endian.h" +#include "arrow/util/ubsan.h" #include "arrow/util/uri.h" namespace arrow { @@ -30,15 +32,11 @@ namespace transport { namespace ucx { static inline void UInt32ToBytesBe(const uint32_t in, uint8_t* out) { - out[0] = static_cast((in >> 24) & 0xFF); - out[1] = static_cast((in >> 16) & 0xFF); - out[2] = static_cast((in >> 8) & 0xFF); - out[3] = static_cast(in & 0xFF); + util::SafeStore(out, bit_util::ToBigEndian(in)); } static inline uint32_t BytesToUInt32Be(const uint8_t* in) { - return static_cast(in[3]) | (static_cast(in[2]) << 8) | - (static_cast(in[1]) << 16) | (static_cast(in[0]) << 24); + return bit_util::FromBigEndian(util::SafeLoadAs(in)); } ARROW_FLIGHT_EXPORT From 6efafdeca29d8c5c49a16b096b41dadc107f53ec Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 30 Mar 2022 10:47:25 -0400 Subject: [PATCH 10/16] ARROW-15706: [C++][FlightRPC] Ignore spurious disconnect problems --- .../arrow/flight/transport/ucx/ucx_client.cc | 5 + .../flight/transport/ucx/util_internal.cc | 118 ++++++++++++------ .../flight/transport/ucx/util_internal.h | 14 +++ 3 files changed, 97 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index 8b01442a687..c00788037ba 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -149,6 +149,11 @@ class ClientConnection { if (!driver_) return Status::OK(); auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0); + const auto ucs_status = FlightUcxStatusDetail::Unwrap(status); + if (ucs_status == UCS_ERR_ENDPOINT_TIMEOUT || ucs_status == UCS_ERR_NOT_CONNECTED) { + // Ignore timeout, not connected + status = Status::OK(); + } status = MergeStatuses(std::move(status), driver_->Close()); driver_.reset(); diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc index 5414ebb6013..dacfdfce25f 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -39,6 +39,13 @@ namespace flight { namespace transport { namespace ucx { +constexpr char FlightUcxStatusDetail::kTypeId[]; +std::string FlightUcxStatusDetail::ToString() const { return ucs_status_string(status_); } +ucs_status_t FlightUcxStatusDetail::Unwrap(const Status& status) { + if (!status.detail() || status.detail()->type_id() != kTypeId) return UCS_OK; + return dynamic_cast(status.detail().get())->status_; +} + arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, struct sockaddr_storage* addr) { std::string host = uri.host(); @@ -125,123 +132,154 @@ Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status) { return Status::OK(); case UCS_INPROGRESS: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_INPROGRESS ", ucs_status_string(ucs_status)); + ": ", "UCS_INPROGRESS ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NO_MESSAGE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_MESSAGE ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_NO_MESSAGE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NO_RESOURCE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_RESOURCE ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_NO_RESOURCE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_IO_ERROR: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_IO_ERROR ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_IO_ERROR ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NO_MEMORY: return Status::OutOfMemory(context, ": UCX error ", static_cast(ucs_status), ": ", - "UCS_ERR_NO_MEMORY ", ucs_status_string(ucs_status)); + "UCS_ERR_NO_MEMORY ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_INVALID_PARAM: return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_INVALID_PARAM ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_UNREACHABLE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_UNREACHABLE ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_UNREACHABLE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_INVALID_ADDR: return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_INVALID_ADDR ", - ucs_status_string(ucs_status)); + ": ", "UCS_ERR_INVALID_ADDR ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NOT_IMPLEMENTED: return Status::NotImplemented( - context, ": UCX error ", static_cast(ucs_status), ": ", - "UCS_ERR_NOT_IMPLEMENTED ", ucs_status_string(ucs_status)); + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_NOT_IMPLEMENTED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_MESSAGE_TRUNCATED: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_MESSAGE_TRUNCATED ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NO_PROGRESS: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_PROGRESS ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_NO_PROGRESS ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_BUFFER_TOO_SMALL: return Status::Invalid(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_BUFFER_TOO_SMALL ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NO_ELEM: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_ELEM ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_NO_ELEM ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_SOME_CONNECTS_FAILED: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_SOME_CONNECTS_FAILED ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NO_DEVICE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_NO_DEVICE ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_NO_DEVICE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_BUSY: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_BUSY ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_BUSY ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_CANCELED: return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_CANCELED ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_CANCELED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_SHMEM_SEGMENT: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_SHMEM_SEGMENT ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_ALREADY_EXISTS: return Status::AlreadyExists( - context, ": UCX error ", static_cast(ucs_status), ": ", - "UCS_ERR_ALREADY_EXISTS ", ucs_status_string(ucs_status)); + context, ": UCX error ", static_cast(ucs_status), ": ", + "UCS_ERR_ALREADY_EXISTS ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_OUT_OF_RANGE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_OUT_OF_RANGE ", - ucs_status_string(ucs_status)); + ": ", "UCS_ERR_OUT_OF_RANGE ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_TIMED_OUT: return Status::Cancelled(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_TIMED_OUT ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_TIMED_OUT ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_EXCEEDS_LIMIT: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_EXCEEDS_LIMIT ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_UNSUPPORTED: - return Status::NotImplemented( - context, ": UCX error ", static_cast(ucs_status), ": ", - "UCS_ERR_UNSUPPORTED ", ucs_status_string(ucs_status)); + return Status::NotImplemented(context, ": UCX error ", + static_cast(ucs_status), ": ", + "UCS_ERR_UNSUPPORTED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_REJECTED: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_REJECTED ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_REJECTED ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_NOT_CONNECTED: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_NOT_CONNECTED ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_CONNECTION_RESET: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_CONNECTION_RESET ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_FIRST_LINK_FAILURE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_FIRST_LINK_FAILURE ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_LAST_LINK_FAILURE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_LAST_LINK_FAILURE ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_FIRST_ENDPOINT_FAILURE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_FIRST_ENDPOINT_FAILURE ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_LAST_ENDPOINT_FAILURE: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_LAST_ENDPOINT_FAILURE ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_ENDPOINT_TIMEOUT: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), ": ", "UCS_ERR_ENDPOINT_TIMEOUT ", - ucs_status_string(ucs_status)); + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); case UCS_ERR_LAST: return Status::IOError(context, ": UCX error ", static_cast(ucs_status), - ": ", "UCS_ERR_LAST ", ucs_status_string(ucs_status)); + ": ", "UCS_ERR_LAST ", ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); default: return Status::UnknownError( - context, ": Unknown UCX error: ", static_cast(ucs_status), " ", - ucs_status_string(ucs_status)); + context, ": Unknown UCX error: ", static_cast(ucs_status), " ", + ucs_status_string(ucs_status)) + .WithDetail(std::make_shared(ucs_status)); } } diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h index 3985ec39e16..d11ea2bfe3e 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.h +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h @@ -22,6 +22,7 @@ #include #include "arrow/flight/visibility.h" +#include "arrow/status.h" #include "arrow/util/endian.h" #include "arrow/util/ubsan.h" #include "arrow/util/uri.h" @@ -39,6 +40,19 @@ static inline uint32_t BytesToUInt32Be(const uint8_t* in) { return bit_util::FromBigEndian(util::SafeLoadAs(in)); } +class ARROW_FLIGHT_EXPORT FlightUcxStatusDetail : public StatusDetail { + public: + explicit FlightUcxStatusDetail(ucs_status_t status) : status_(status) {} + static constexpr char const kTypeId[] = "flight::transport::ucx::FlightUcxStatusDetail"; + + const char* type_id() const override { return kTypeId; } + std::string ToString() const override; + static ucs_status_t Unwrap(const Status& status); + + private: + ucs_status_t status_; +}; + ARROW_FLIGHT_EXPORT Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status); From f8392a05aac53b248c9e31d953bbde10804a9b39 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 31 Mar 2022 08:12:08 -0400 Subject: [PATCH 11/16] ARROW-15706: [C++][FlightRPC] Address feedback --- cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 4 ++-- cpp/src/arrow/flight/transport/ucx/util_internal.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index a4cc99dc675..d75c396d490 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -217,8 +217,8 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { Status Init(const FlightServerOptions& options, const arrow::internal::Uri& uri) override { - const auto num_threads = std::max(8, std::thread::hardware_concurrency()); - ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(num_threads)); + const auto max_threads = std::max(8, std::thread::hardware_concurrency()); + ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(max_threads)); struct sockaddr_storage listen_addr; ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr)); diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc index dacfdfce25f..ca4df21a055 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc @@ -82,7 +82,7 @@ arrow::Result UriToSockaddr(const arrow::internal::Uri& uri, } else if (cur_info->ai_family == AF_INET6) { reinterpret_cast(addr)->sin6_port = htons(uri.port()); } - size_t addrlen = info->ai_addrlen; + size_t addrlen = cur_info->ai_addrlen; freeaddrinfo(info); return addrlen; } From 5177eec3faa1ccad163b9924257f9eb73514d582 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Apr 2022 08:07:15 -0400 Subject: [PATCH 12/16] ARROW-15706: [C++][FlightRPC] Factor out disconnect logic --- cpp/src/arrow/flight/transport/ucx/ucx_client.cc | 3 +-- cpp/src/arrow/flight/transport/ucx/ucx_internal.cc | 4 +--- cpp/src/arrow/flight/transport/ucx/util_internal.h | 10 ++++++++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index c00788037ba..d040a9c220d 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -150,8 +150,7 @@ class ClientConnection { auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0); const auto ucs_status = FlightUcxStatusDetail::Unwrap(status); - if (ucs_status == UCS_ERR_ENDPOINT_TIMEOUT || ucs_status == UCS_ERR_NOT_CONNECTED) { - // Ignore timeout, not connected + if (IsIgnorableDisconnectError(ucs_status)) { status = Status::OK(); } status = MergeStatuses(std::move(status), driver_->Close()); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index 3ed33513d01..d9c1ecb4a57 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -815,9 +815,7 @@ class UcpCallDriver::Impl { } endpoint_ = nullptr; - if (status != UCS_OK && status != UCS_ERR_ENDPOINT_TIMEOUT && - status != UCS_ERR_NOT_CONNECTED) { - // Ignore timeout, not connected + if (status != UCS_OK && !IsIgnorableDisconnectError(status)) { return FromUcsStatus(origin, status); } return Status::OK(); diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h index d11ea2bfe3e..84e84ba0711 100644 --- a/cpp/src/arrow/flight/transport/ucx/util_internal.h +++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h @@ -53,9 +53,19 @@ class ARROW_FLIGHT_EXPORT FlightUcxStatusDetail : public StatusDetail { ucs_status_t status_; }; +/// \brief Convert a UCS status to an Arrow Status. ARROW_FLIGHT_EXPORT Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status); +/// \brief Check if a UCS error code can be ignored in the context of +/// a disconnect. +static inline bool IsIgnorableDisconnectError(ucs_status_t ucs_status) { + // Not connected, connection reset: we're already disconnected + // Timeout: most likely disconnected, but we can't tell from our end + return ucs_status == UCS_OK || ucs_status == UCS_ERR_ENDPOINT_TIMEOUT || + ucs_status == UCS_ERR_NOT_CONNECTED || ucs_status == UCS_ERR_CONNECTION_RESET; +} + /// \brief Helper to convert a Uri to a struct sockaddr (used in /// ucp_listener_params_t) /// From 9fbc756d59e878d5309ef65b253a5263595cbd26 Mon Sep 17 00:00:00 2001 From: David Li Date: Sat, 2 Apr 2022 11:57:20 -0400 Subject: [PATCH 13/16] ARROW-15706: [C++][FlightRPC] Document some design choices --- .../arrow/flight/transport/ucx/ucx_client.cc | 19 +++++++++++-------- .../flight/transport/ucx/ucx_internal.cc | 19 +++++++++++++------ .../arrow/flight/transport/ucx/ucx_server.cc | 5 +++++ 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index d040a9c220d..3a216e0f726 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -260,7 +260,7 @@ class WriteClientStream : public UcxClientStream { if (finished_ || writes_done_) return Status::Invalid("Already done writing"); outgoing_ = driver_->SendFlightPayload(payload); working_cv_.notify_all(); - received_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); + completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); auto status = outgoing_.status(); outgoing_ = Future<>(); @@ -274,7 +274,7 @@ class WriteClientStream : public UcxClientStream { outgoing_ = driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer()); working_cv_.notify_all(); - received_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); + completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); }); writes_done_ = true; auto status = outgoing_.status(); @@ -292,6 +292,9 @@ class WriteClientStream : public UcxClientStream { // Ignore } } + // Flight's API allows concurrent reads/writes, but the UCX driver + // here is single-threaded, so push all UCX work onto a single + // worker thread void DriveWorker() { while (true) { { @@ -311,11 +314,11 @@ class WriteClientStream : public UcxClientStream { HandleIncomingMessage(*incoming_.result()); } incoming_ = Future>(); - received_cv_.notify_all(); + completed_cv_.notify_all(); break; } if (outgoing_.is_valid() && outgoing_.is_finished()) { - received_cv_.notify_all(); + completed_cv_.notify_all(); break; } driver_->MakeProgress(); @@ -328,7 +331,7 @@ class WriteClientStream : public UcxClientStream { std::mutex driver_mutex_; std::thread driver_thread_; - std::condition_variable received_cv_; + std::condition_variable completed_cv_; std::condition_variable working_cv_; Future> incoming_; Future<> outgoing_; @@ -348,7 +351,7 @@ class PutClientStream : public WriteClientStream { next_metadata_ = nullptr; incoming_ = driver_->ReadFrameAsync(); working_cv_.notify_all(); - received_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; }); + completed_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; }); if (finished_) { *out = nullptr; @@ -408,11 +411,11 @@ class ExchangeClientStream : public WriteClientStream { read_state_ = ReadState::kExpectHeader; incoming_ = driver_->ReadFrameAsync(); working_cv_.notify_all(); - received_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; }); + completed_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; }); if (read_state_ != ReadState::kFinished) { incoming_ = driver_->ReadFrameAsync(); working_cv_.notify_all(); - received_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; }); + completed_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; }); } if (finished_) { diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index d9c1ecb4a57..1b8eb4519f3 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -100,8 +100,7 @@ void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) { class UcxDataBuffer : public Buffer { public: explicit UcxDataBuffer(std::shared_ptr worker, void* data, size_t size) - : Buffer(const_cast(reinterpret_cast(data)), - static_cast(size)), + : Buffer(reinterpret_cast(data), static_cast(size)), worker_(std::move(worker)) {} ~UcxDataBuffer() { @@ -525,13 +524,16 @@ class UcpCallDriver::Impl { std::unique_lock guard(frame_mutex_); if (ARROW_PREDICT_FALSE(!status_.ok())) return status_; + // Expected value of "counter" field in the frame header const uint32_t counter_value = next_counter_++; auto it = frames_.find(counter_value); if (it != frames_.end()) { + // Message already delivered, return it Future> fut = it->second; frames_.erase(it); return fut; } + // Message not yet delivered, insert a future and wait auto pair = frames_.insert({counter_value, Future>::Make()}); DCHECK(pair.second); return pair.first->second; @@ -828,9 +830,14 @@ class UcpCallDriver::Impl { if (ARROW_PREDICT_FALSE(!status_.ok())) return; auto pair = frames_.insert({frame->counter, frame}); if (!pair.second) { + // Not inserted, because ReadFrameAsync was called for this + // frame counter value and the client is already waiting on + // it. Complete the existing future. pair.first->second.MarkFinished(std::move(frame)); frames_.erase(pair.first); } + // Otherwise, we inserted the frame, meaning the client was not + // currently waiting for that frame counter value } void Push(Status status) { @@ -948,9 +955,9 @@ class UcpCallDriver::Impl { const ucp_am_recv_param_t* param) { DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP); - if (data_length > static_cast(std::numeric_limits::max())) { - return Status::Invalid( - "Cannot allocate buffer greater than int64_t max, requested: ", data_length); + if (data_length > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Cannot allocate buffer greater than 2 GiB, requested: ", + data_length); } ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length)); @@ -1021,7 +1028,7 @@ class UcpCallDriver::Impl { return UCS_OK; } else { // Data will be freed after callback returns - copy to buffer - if (frame->type != FrameType::kPayloadBody || memory_manager_->is_cpu()) { + if (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody) { ARROW_ASSIGN_OR_RAISE(frame->buffer, AllocateBuffer(data_length, read_memory_pool_)); std::memcpy(frame->buffer->mutable_data(), data, data_length); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index d75c396d490..5025759e88b 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -322,6 +322,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { // Wait for current RPCs to finish listening_.store(false); + // Unstick the listener thread from ucp_worker_wait RETURN_NOT_OK( FromUcsStatus("ucp_worker_signal", ucp_worker_signal(worker_conn_->get()))); status &= Wait(); @@ -550,6 +551,10 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { } } + // Check listening_ in case we're shutting down. It is possible + // that Shutdown() was called while we were in + // ucp_worker_progress above, in which case if we don't check + // listening_ here, we'll enter ucp_worker_wait and get stuck. if (!listening_.load()) break; auto status = ucp_worker_wait(worker_conn_->get()); if (status != UCS_OK) { From 68901d219f9e7048803a9bb6a050cafed3e9f4a0 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 4 Apr 2022 10:32:36 -0400 Subject: [PATCH 14/16] ARROW-15706: [C++][FlightRPC] Remove incomplete backpressure --- .../flight/transport/ucx/ucx_internal.cc | 2 +- .../arrow/flight/transport/ucx/ucx_server.cc | 29 ++----------------- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index 1b8eb4519f3..215b4611cf3 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -967,7 +967,7 @@ class UcpCallDriver::Impl { } if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) && - (frame->type != FrameType::kPayloadBody || memory_manager_->is_cpu())) { + (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody)) { // Zero-copy path. UCX-allocated buffer must be freed later. // XXX: this buffer can NOT be freed until AFTER we return from diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index 5025759e88b..e77777d2be2 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -79,36 +79,18 @@ class UcxServerCallContext : public flight::ServerCallContext { class UcxServerStream : public internal::ServerDataStream { public: - // TODO(lidavidm): backpressure threshold should be dynamic (ideally - // auto-adjusted, or at least configurable) - constexpr static size_t kBackpressureThreshold = 8; - explicit UcxServerStream(UcpCallDriver* driver) : peer_(driver->peer()), driver_(driver), writes_done_(false) {} Status WritesDone() override { - RETURN_NOT_OK(CheckBackpressure(0)); writes_done_ = true; return Status::OK(); } protected: - Status CheckBackpressure(size_t limit = kBackpressureThreshold - 1) { - while (requests_.size() > limit) { - auto& next = requests_.front(); - while (!next.is_finished()) { - driver_->MakeProgress(); - } - RETURN_NOT_OK(next.status()); - requests_.pop(); - } - return Status::OK(); - } - std::string peer_; UcpCallDriver* driver_; bool writes_done_; - std::queue> requests_; }; class GetServerStream : public UcxServerStream { @@ -117,14 +99,11 @@ class GetServerStream : public UcxServerStream { arrow::Result WriteData(const FlightPayload& payload) override { if (writes_done_) return false; - RETURN_NOT_OK(CheckBackpressure()); Future<> pending_send = driver_->SendFlightPayload(payload); - if (!pending_send.is_finished()) { - requests_.push(std::move(pending_send)); - } else { - // Request completed instantly - RETURN_NOT_OK(pending_send.status()); + while (!pending_send.is_finished()) { + driver_->MakeProgress(); } + RETURN_NOT_OK(pending_send.status()); return true; } }; @@ -188,8 +167,6 @@ class ExchangeServerStream : public PutServerStream { arrow::Result WriteData(const FlightPayload& payload) override { if (writes_done_) return false; - // Don't use backpressure - the application may expect synchronous - // behavior (write a message, read the client response) Future<> pending_send = driver_->SendFlightPayload(payload); while (!pending_send.is_finished()) { driver_->MakeProgress(); From 35fe9806edda9ae302ce440b8319b2c6a195ce74 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Apr 2022 10:45:44 -0400 Subject: [PATCH 15/16] ARROW-15706: [C++][FlightRPC] Update for deprecations --- .../ucx/flight_transport_ucx_test.cc | 47 +++++++------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc index 2795a2996f5..6a580af92fd 100644 --- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc +++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc @@ -296,40 +296,36 @@ class TestUcx : public ::testing::Test { TEST_F(TestUcx, GetFlightInfo) { auto descriptor = FlightDescriptor::Path({"foo", "bar"}); std::unique_ptr info; - ASSERT_OK(client_->GetFlightInfo(descriptor, &info)); + ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor)); // Test that we can reuse the connection - ASSERT_OK(client_->GetFlightInfo(descriptor, &info)); + ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor)); } TEST_F(TestUcx, SequentialClients) { - std::unique_ptr client2; - ASSERT_OK(FlightClient::Connect(server_->location(), FlightClientOptions::Defaults(), - &client2)); + ASSERT_OK_AND_ASSIGN( + auto client2, + FlightClient::Connect(server_->location(), FlightClientOptions::Defaults())); Ticket ticket{"a"}; - std::unique_ptr stream1, stream2; - - ASSERT_OK(client_->DoGet(ticket, &stream1)); + ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket)); ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); - ASSERT_OK(client2->DoGet(ticket, &stream2)); + ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket)); ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); AssertTablesEqual(*table1, *table2); } TEST_F(TestUcx, ConcurrentClients) { - std::unique_ptr client2; - ASSERT_OK(FlightClient::Connect(server_->location(), FlightClientOptions::Defaults(), - &client2)); + ASSERT_OK_AND_ASSIGN( + auto client2, + FlightClient::Connect(server_->location(), FlightClientOptions::Defaults())); Ticket ticket{"a"}; - std::unique_ptr stream1, stream2; - - ASSERT_OK(client_->DoGet(ticket, &stream1)); - ASSERT_OK(client2->DoGet(ticket, &stream2)); + ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket)); + ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket)); ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable()); ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable()); @@ -339,15 +335,13 @@ TEST_F(TestUcx, ConcurrentClients) { TEST_F(TestUcx, Errors) { auto descriptor = FlightDescriptor::Path({"error", "bar"}); - std::unique_ptr info; - auto* server = reinterpret_cast(server_.get()); for (const auto code : kStatusCodes) { if (code == StatusCode::OK) continue; Status expected(code, "Error message"); server->set_error_status(expected); - Status actual = client_->GetFlightInfo(descriptor, &info); + Status actual = client_->GetFlightInfo(descriptor).status(); ASSERT_EQ(actual, expected); // Attach a generic status detail @@ -357,7 +351,7 @@ TEST_F(TestUcx, Errors) { Status expected(code, "foo", std::make_shared(FlightStatusCode::Internal, detail->ToString())); - Status actual = client_->GetFlightInfo(descriptor, &info); + Status actual = client_->GetFlightInfo(descriptor).status(); ASSERT_EQ(actual, expected); } @@ -366,17 +360,13 @@ TEST_F(TestUcx, Errors) { Status expected(code, "Error message", std::make_shared(flight_code, "extra")); server->set_error_status(expected); - Status actual = client_->GetFlightInfo(descriptor, &info); + Status actual = client_->GetFlightInfo(descriptor).status(); ASSERT_EQ(actual, expected); } } } TEST(TestUcxIpV6, DISABLED_IpV6Port) { - // TODO(lidavidm): while we can listen on IPv6 fine, we can't - // actually connect to it (ucp_conn_request_h appears to point to a - // port where nobody is listening) - // Also, disabled in CI as machines lack an IPv6 interface ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0)); @@ -384,13 +374,12 @@ TEST(TestUcxIpV6, DISABLED_IpV6Port) { FlightServerOptions server_options(location); ASSERT_OK(server->Init(server_options)); - std::unique_ptr client; FlightClientOptions client_options = FlightClientOptions::Defaults(); - ASSERT_OK(FlightClient::Connect(server->location(), client_options, &client)); + ASSERT_OK_AND_ASSIGN(auto client, + FlightClient::Connect(server->location(), client_options)); auto descriptor = FlightDescriptor::Path({"foo", "bar"}); - std::unique_ptr info; - ASSERT_OK(client->GetFlightInfo(descriptor, &info)); + ASSERT_OK_AND_ASSIGN(auto info, client->GetFlightInfo(descriptor)); } } // namespace flight From 6d7b0de90b0a4f59cab6774c77fde46f1ae371b7 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 5 Apr 2022 10:57:56 -0400 Subject: [PATCH 16/16] ARROW-15706: [C++][FlightRPC] Fill in TODOs --- cpp/src/arrow/flight/transport/ucx/ucx_client.cc | 2 +- cpp/src/arrow/flight/transport/ucx/ucx_internal.cc | 8 ++++---- cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index 3a216e0f726..173132062e5 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -673,7 +673,7 @@ class UcxClientImpl : public arrow::flight::internal::ClientTransport { Status ReturnConnection(ClientConnection conn) { std::unique_lock connections_mutex_; - // TODO(lidavidm): for future improvement: reclaim clients + // TODO(ARROW-16127): for future improvement: reclaim clients // asynchronously in the background (try to avoid issues like // constantly opening/closing clients because the application is // just barely over the limit of open connections) diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index 215b4611cf3..ab4cc323f4c 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -684,7 +684,7 @@ class UcpCallDriver::Impl { if (kEnableContigSend && all_cpu) { // CONTIG - concatenate buffers into one before sending - // TODO(lidavidm): this needs to be pipelined since it can be expensive. + // TODO(ARROW-16126): this needs to be pipelined since it can be expensive. // Preliminary profiling shows ~5% overhead just from mapping the buffer // alone (on Infiniband; it seems to be trivial for shared memory) request_param.datatype = ucp_dt_make_contig(1); @@ -929,8 +929,8 @@ class UcpCallDriver::Impl { } else { pending_send->completed.MarkFinished(FromUcsStatus("ucp_am_send_nbx", status)); } - // TODO(lidavidm): delete should occur on a background thread if there's mapped - // buffers, since unmapping can be nontrivial and we don't want to block + // TODO(ARROW-16126): delete should occur on a background thread if there's + // mapped buffers, since unmapping can be nontrivial and we don't want to block // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.) delete pending_send; ucp_request_free(request); @@ -989,7 +989,7 @@ class UcpCallDriver::Impl { // We want to map/pin/register the buffer for faster transfer // where possible. (It gets unmapped in ~PendingAmRecv.) - // TODO(lidavidm): This takes non-trivial time, so return + // TODO(ARROW-16126): This takes non-trivial time, so return // UCS_INPROGRESS, kick off the allocation in the background, // and recv the data later (is it allowed to call // ucp_am_recv_data_nbx asynchronously?). diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index e77777d2be2..74a9311d0c8 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -325,7 +325,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { } Status Shutdown(const std::chrono::system_clock::time_point& deadline) override { - // TODO(lidavidm): implement shutdown with deadline + // TODO(ARROW-16125): implement shutdown with deadline return Shutdown(); } @@ -578,7 +578,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { static void HandleIncomingConnection(ucp_conn_request_h connection_request, void* data) { UcxServerImpl* server = reinterpret_cast(data); - // TODO(lidavidm): enable shedding load above some threshold + // TODO(ARROW-16124): enable shedding load above some threshold // (which is a pitfall with gRPC/Java) server->EnqueueClient(connection_request); }