From ed60fb22fe94395951dc0d9fa2f766f95f1f9782 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Thu, 4 Nov 2021 07:21:52 +0000 Subject: [PATCH] [RFC] ARROW-15282: [C++][FlightRPC] Support non-grpc data planes This patch decouples flightrpc data plane from grpc so we can leverage optimized data transfer libraries. The basic idea is to replace grpc stream with a data plane stream for FlightData transmission in DoGet/DoPut/DoExchange. There's no big change to current flight client and server implementations. Added a wrapper to support both grpc stream and data plane stream. By default, grpc stream is used, which follows the current grpc-based code path. If a data plane is enabled (currently through environment variable), flight payload will go through the data plane stream instead. See client.cc and server.cc to review the changes. **About data plane implementation** - data_plane/types.{h,cc} Defines client/server data plane and data plane stream interfaces. It's the only exported api to other components ({client,server}.cc). - data_plane/serialize.{h,cc} De-Serialize FlightData manually as we bypass grpc. Luckily, we already implemented related functions to support payload zero-copy. - shm.cc A shared memory driver to verify the data plane approach. The code may be a bit hard to read, it's better to focus on data plane interface implementations at first before diving deep into details like shared memory, ipc and buffer management related code. Please note there are still many caveats in current code, see TODO and XXX in shm.cc for details. **To evaluate this patch** I tested shared memory data plane on Linux (x86, Arm) and MacOS (Arm). Build with `-DARROW_FLIGHT_DP_SHM=ON` to enable the shared memory data plane. Set `FLIGHT_DATAPLANE=shm` environment variable to run unit tests and benchmarks with the shared memory data plane enabled. ``` Build: cmake -DARROW_FLIGHT_DP_SHM=ON -DARROW_FLIGHT=ON .... 
Test: FLIGHT_DATAPLANE=shm release/arrow-flight-test Bench: FLIGHT_DATAPLANE=shm release/arrow-flight-benchmark \ -num_streams=1|2|4 -num_threads=1|2|4 ``` Benchmark result (throughput, latency) on Xeon Gold 5218. Test case: DoGet, batch size = 128KiB | streams | grpc over unix socket | shared memory data plane | | ------- | --------------------- | ------------------------ | | 1 | 3324 MB/s, 35 us | 7045 MB/s, 16 us | | 2 | 6289 MB/s, 38 us | 13311 MB/s, 17 us | | 4 | 10037 MB/s, 44 us | 25012 MB/s, 17 us | --- cpp/cmake_modules/DefineOptions.cmake | 2 + cpp/src/arrow/flight/CMakeLists.txt | 11 +- cpp/src/arrow/flight/client.cc | 251 ++++-- cpp/src/arrow/flight/data_plane/internal.h | 45 + cpp/src/arrow/flight/data_plane/serialize.cc | 106 +++ cpp/src/arrow/flight/data_plane/serialize.h | 92 ++ cpp/src/arrow/flight/data_plane/shm.cc | 826 ++++++++++++++++++ cpp/src/arrow/flight/data_plane/types.cc | 127 +++ cpp/src/arrow/flight/data_plane/types.h | 161 ++++ cpp/src/arrow/flight/perf_server.cc | 2 +- cpp/src/arrow/flight/serialization_internal.h | 10 +- cpp/src/arrow/flight/server.cc | 200 +++-- cpp/src/arrow/flight/server.h | 8 + 13 files changed, 1710 insertions(+), 131 deletions(-) create mode 100644 cpp/src/arrow/flight/data_plane/internal.h create mode 100644 cpp/src/arrow/flight/data_plane/serialize.cc create mode 100644 cpp/src/arrow/flight/data_plane/serialize.h create mode 100644 cpp/src/arrow/flight/data_plane/shm.cc create mode 100644 cpp/src/arrow/flight/data_plane/types.cc create mode 100644 cpp/src/arrow/flight/data_plane/types.h diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 0a43ec18f60..fcfc4f3f8ae 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -234,6 +234,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_FLIGHT_SQL "Build the Arrow Flight SQL extension" OFF) + define_option(ARROW_FLIGHT_DP_SHM "Build the Arrow 
Flight shared memory data plane" OFF) + define_option(ARROW_GANDIVA "Build the Gandiva libraries" OFF) define_option(ARROW_GCS diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 2cf8c9913e5..5dcbf32631c 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -151,6 +151,13 @@ endif() # Restore the CXXFLAGS that were modified above set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") +# Data plane source files +# TODO(yibo): separate common (serialize.cc) and driver (shm.cc) +if(ARROW_FLIGHT_DP_SHM) + add_definitions(-DFLIGHT_DP_SHM) + set(DATAPLANE_SRCS data_plane/serialize.cc data_plane/shm.cc) +endif() + # Note, we do not compile the generated Protobuf sources directly, instead # compiling then via protocol_internal.cc which contains some gRPC template # overrides to enable Flight-specific optimizations. See comments in @@ -164,7 +171,8 @@ set(ARROW_FLIGHT_SRCS serialization_internal.cc server.cc server_auth.cc - types.cc) + types.cc + data_plane/types.cc) add_arrow_lib(arrow_flight CMAKE_PACKAGE_NAME @@ -175,6 +183,7 @@ add_arrow_lib(arrow_flight ARROW_FLIGHT_LIBRARIES SOURCES ${ARROW_FLIGHT_SRCS} + ${DATAPLANE_SRCS} PRECOMPILED_HEADERS "$<$:arrow/flight/pch.h>" DEPENDENCIES diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f9728f849ad..fed06bd5fd3 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -59,6 +59,8 @@ #include "arrow/flight/serialization_internal.h" #include "arrow/flight/types.h" +#include "arrow/flight/data_plane/types.h" + namespace arrow { namespace flight { @@ -127,6 +129,93 @@ struct ClientRpc { } }; +namespace internal { + +template +Status FinishClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream, + ClientRpc* rpc) { + Status st1, st2; + if (data_stream) { + st1 = data_stream->Finish(); + } + st2 = internal::FromGrpcStatus(grpc_stream->Finish(), &rpc->context); + return st2.ok() ? 
st1 : st2; +} + +void CancelClient(DataClientStream* data_stream, ClientRpc* rpc) { + rpc->context.TryCancel(); + if (data_stream) { + data_stream->TryCancel(); + } +} + +template +Status ReadClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream, + FlightData* payload) { + if (data_stream) { + return data_stream->Read(payload); + } else { + const bool ok = internal::ReadPayload(grpc_stream, payload); + return ok ? Status::OK() : Status::IOError("gRPC client read failed"); + } +} + +template +Status WriteClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream, + const FlightPayload& payload) { + if (data_stream) { + return data_stream->Write(payload); + } else { + return internal::WritePayload(payload, grpc_stream); + } +} + +template +Status WritesDoneClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream) { + const bool ok = grpc_stream->WritesDone(); + Status st; + if (data_stream) { + st = data_stream->WritesDone(); + } + return ok ? st : Status::IOError("gRPC WritesDone failed"); +} + +struct ClientReader { + using GrpcClientStream = grpc::ClientReader; + std::shared_ptr grpc_stream; + std::shared_ptr data_stream; + + Status Read(FlightData* payload) { + return ReadClient(grpc_stream.get(), data_stream.get(), payload); + } + Status Finish(ClientRpc* rpc) { + return FinishClient(grpc_stream.get(), data_stream.get(), rpc); + } + void Cancel(ClientRpc* rpc) { CancelClient(data_stream.get(), rpc); } +}; + +// ReadT can be pb::FlightData or pb::PutResult +template +struct ClientReaderWriter { + using GrpcClientStream = grpc::ClientReaderWriter; + std::shared_ptr grpc_stream; + std::shared_ptr data_stream; + + Status Read(FlightData* payload) { + return ReadClient(grpc_stream.get(), data_stream.get(), payload); + } + Status Write(const FlightPayload& payload) { + return WriteClient(grpc_stream.get(), data_stream.get(), payload); + } + Status WritesDone() { return WritesDoneClient(grpc_stream.get(), data_stream.get()); } + Status 
Finish(ClientRpc* rpc) { + return FinishClient(grpc_stream.get(), data_stream.get(), rpc); + } + void Cancel(ClientRpc* rpc) { CancelClient(data_stream.get(), rpc); } +}; + +} // namespace internal + /// Helper that manages Finish() of a gRPC stream. /// /// When we encounter an error (e.g. could not decode an IPC message), @@ -139,15 +228,21 @@ struct ClientRpc { /// /// The template lets us abstract between DoGet/DoExchange and DoPut, /// which respectively read internal::FlightData and pb::PutResult. +// +// Stream contains two shared_ptrs to grpc stream and possibly an +// additional data stream. See internal::ClientReader for details. template class FinishableStream { public: - FinishableStream(std::shared_ptr rpc, std::shared_ptr stream) - : rpc_(rpc), stream_(stream), finished_(false), server_status_() {} + FinishableStream(std::shared_ptr rpc, Stream stream) + : rpc_(std::move(rpc)), + stream_(std::move(stream)), + finished_(false), + server_status_() {} virtual ~FinishableStream() = default; /// \brief Get the underlying stream. - std::shared_ptr stream() const { return stream_; } + Stream stream() const { return stream_; } /// \brief Finish the call, adding server context to the given status. virtual Status Finish(Status st) { @@ -162,11 +257,11 @@ class FinishableStream { // indicate that it is done writing, but not done reading, it // should use DoneWriting. 
ReadT message; - while (internal::ReadPayload(stream_.get(), &message)) { + while (internal::ReadPayload(stream_.grpc_stream.get(), &message)) { // Drain the read side to avoid gRPC hanging in Finish() } - server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context); + server_status_ = stream_.Finish(rpc_.get()); finished_ = true; return MergeStatus(std::move(st)); @@ -184,7 +279,7 @@ class FinishableStream { } std::shared_ptr rpc_; - std::shared_ptr stream_; + Stream stream_; bool finished_; Status server_status_; }; @@ -197,9 +292,8 @@ template class FinishableWritableStream : public FinishableStream { public: FinishableWritableStream(std::shared_ptr rpc, - std::shared_ptr read_mutex, - std::shared_ptr stream) - : FinishableStream(rpc, stream), + std::shared_ptr read_mutex, Stream stream) + : FinishableStream(std::move(rpc), std::move(stream)), finish_mutex_(), read_mutex_(read_mutex), done_writing_(false) {} @@ -213,7 +307,7 @@ class FinishableWritableStream : public FinishableStream { return Status::OK(); } done_writing_ = true; - if (!this->stream()->WritesDone()) { + if (!this->stream().WritesDone().ok()) { // Error happened, try to close the stream to get more detailed info return Finish(MakeFlightError(FlightStatusCode::Internal, "Could not flush pending record batches")); @@ -239,7 +333,7 @@ class FinishableWritableStream : public FinishableStream { // Try to flush pending writes. Don't use our WritesDone() to // avoid recursion. 
- bool finished_writes = done_writing_ || this->stream()->WritesDone(); + bool finished_writes = done_writing_ || this->stream().WritesDone().ok(); done_writing_ = true; st = FinishableStream::Finish(std::move(st)); @@ -436,8 +530,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader { GrpcIpcMessageReader( std::shared_ptr rpc, std::shared_ptr read_mutex, std::shared_ptr> stream, - std::shared_ptr>> - peekable_reader, + std::shared_ptr> peekable_reader, std::shared_ptr* app_metadata) : rpc_(rpc), read_mutex_(read_mutex), @@ -476,8 +569,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader { // side calls Finish(). Nullable as DoGet doesn't need this. std::shared_ptr read_mutex_; std::shared_ptr> stream_; - std::shared_ptr>> - peekable_reader_; + std::shared_ptr> peekable_reader_; // A reference to GrpcStreamReader.app_metadata_. That class // can't access the app metadata because when it Peek()s the stream, // it may be looking at a dictionary batch, not the record @@ -490,18 +582,19 @@ class GrpcIpcMessageReader : public ipc::MessageReader { /// The implementation of the public-facing API for reading from a /// FlightData stream template -class GrpcStreamReader : public FlightStreamReader { +class ClientStreamReader : public FlightStreamReader { public: - GrpcStreamReader(std::shared_ptr rpc, std::shared_ptr read_mutex, - const ipc::IpcReadOptions& options, StopToken stop_token, - std::shared_ptr> stream) + ClientStreamReader( + std::shared_ptr rpc, std::shared_ptr read_mutex, + const ipc::IpcReadOptions& options, StopToken stop_token, + std::shared_ptr> stream) : rpc_(rpc), read_mutex_(read_mutex), options_(options), stop_token_(std::move(stop_token)), stream_(stream), - peekable_reader_(new internal::PeekableFlightDataReader>( - stream->stream())), + peekable_reader_( + new internal::PeekableFlightDataReader(stream->stream())), app_metadata_(nullptr) {} Status EnsureDataStarted() { @@ -584,7 +677,7 @@ class GrpcStreamReader : public 
FlightStreamReader { return ReadAll(table, stop_token_); } using FlightStreamReader::ReadAll; - void Cancel() override { rpc_->context.TryCancel(); } + void Cancel() override { stream_->stream().Cancel(rpc_.get()); } private: std::unique_lock TakeGuard() { @@ -608,8 +701,7 @@ class GrpcStreamReader : public FlightStreamReader { ipc::IpcReadOptions options_; StopToken stop_token_; std::shared_ptr> stream_; - std::shared_ptr>> - peekable_reader_; + std::shared_ptr> peekable_reader_; std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; @@ -628,16 +720,16 @@ template class DoPutPayloadWriter; template -class GrpcStreamWriter : public FlightStreamWriter { +class ClientStreamWriter : public FlightStreamWriter { public: - ~GrpcStreamWriter() override = default; + ~ClientStreamWriter() override = default; - using GrpcStream = grpc::ClientReaderWriter; + using ClientStream = internal::ClientReaderWriter; - explicit GrpcStreamWriter( + explicit ClientStreamWriter( const FlightDescriptor& descriptor, std::shared_ptr rpc, int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options, - std::shared_ptr> writer) + std::shared_ptr> writer) : app_metadata_(nullptr), batch_writer_(nullptr), writer_(std::move(writer)), @@ -651,7 +743,7 @@ class GrpcStreamWriter : public FlightStreamWriter { const FlightDescriptor& descriptor, std::shared_ptr schema, const ipc::IpcWriteOptions& options, std::shared_ptr rpc, int64_t write_size_limit_bytes, - std::shared_ptr> writer, + std::shared_ptr> writer, std::unique_ptr* out); Status CheckStarted() { @@ -688,7 +780,7 @@ class GrpcStreamWriter : public FlightStreamWriter { Status WriteMetadata(std::shared_ptr app_metadata) override { FlightPayload payload{}; payload.app_metadata = app_metadata; - auto status = internal::WritePayload(payload, writer_->stream().get()); + auto status = writer_->stream().Write(payload); if (status.IsIOError()) { return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, "Could not write 
metadata to stream")); @@ -741,7 +833,7 @@ class GrpcStreamWriter : public FlightStreamWriter { friend class DoPutPayloadWriter; std::shared_ptr app_metadata_; std::unique_ptr batch_writer_; - std::shared_ptr> writer_; + std::shared_ptr> writer_; // Fields used to lazy-initialize the IpcPayloadWriter. They're // invalid once Begin() is called. @@ -757,13 +849,13 @@ class GrpcStreamWriter : public FlightStreamWriter { template class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { public: - using GrpcStream = grpc::ClientReaderWriter; + using ClientStream = internal::ClientReaderWriter; DoPutPayloadWriter( const FlightDescriptor& descriptor, std::shared_ptr rpc, int64_t write_size_limit_bytes, - std::shared_ptr> writer, - GrpcStreamWriter* stream_writer) + std::shared_ptr> writer, + ClientStreamWriter* stream_writer) : descriptor_(descriptor), rpc_(rpc), write_size_limit_bytes_(write_size_limit_bytes), @@ -809,7 +901,7 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { } } - auto status = internal::WritePayload(payload, writer_->stream().get()); + auto status = writer_->stream().Write(payload); if (status.IsIOError()) { return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, "Could not write record batch to stream")); @@ -826,21 +918,21 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { const FlightDescriptor descriptor_; std::shared_ptr rpc_; int64_t write_size_limit_bytes_; - std::shared_ptr> writer_; + std::shared_ptr> writer_; bool first_payload_; - GrpcStreamWriter* stream_writer_; + ClientStreamWriter* stream_writer_; }; template -Status GrpcStreamWriter::Open( +Status ClientStreamWriter::Open( const FlightDescriptor& descriptor, std::shared_ptr schema, // this schema is nullable const ipc::IpcWriteOptions& options, std::shared_ptr rpc, int64_t write_size_limit_bytes, - std::shared_ptr> writer, + std::shared_ptr> writer, std::unique_ptr* out) { - std::unique_ptr> instance( - new GrpcStreamWriter( 
+ std::unique_ptr> instance( + new ClientStreamWriter( descriptor, std::move(rpc), write_size_limit_bytes, options, writer)); if (schema) { // The schema was provided (DoPut). Eagerly write the schema and @@ -852,7 +944,7 @@ Status GrpcStreamWriter::Open( // calls Begin() to send data, we'll send a redundant descriptor. FlightPayload payload{}; RETURN_NOT_OK(internal::ToPayload(descriptor, &payload.descriptor)); - auto status = internal::WritePayload(payload, instance->writer_->stream().get()); + auto status = instance->writer_->stream().Write(payload); if (status.IsIOError()) { return writer->Finish(MakeFlightError(FlightStatusCode::Internal, "Could not write descriptor to stream")); @@ -917,6 +1009,8 @@ constexpr char kDummyRootCert[] = class FlightClient::FlightClientImpl { public: Status Connect(const Location& location, const FlightClientOptions& options) { + ARROW_ASSIGN_OR_RAISE(data_plane_, internal::ClientDataPlane::Make(location)); + const std::string& scheme = location.scheme(); std::stringstream grpc_uri; @@ -1194,19 +1288,25 @@ class FlightClient::FlightClientImpl { Status DoGet(const FlightCallOptions& options, const Ticket& ticket, std::unique_ptr* out) { - using StreamReader = GrpcStreamReader>; + using ClientStream = internal::ClientReader; + using StreamReader = ClientStreamReader; pb::Ticket pb_ticket; internal::ToProto(ticket, &pb_ticket); auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr> stream = + + ARROW_ASSIGN_OR_RAISE(auto data_stream, data_plane_->DoGet(&rpc->context)); + + std::shared_ptr> grpc_stream = stub_->DoGet(&rpc->context, pb_ticket); - auto finishable_stream = std::make_shared< - FinishableStream, internal::FlightData>>( - rpc, stream); - *out = std::unique_ptr(new StreamReader( - rpc, nullptr, options.read_options, options.stop_token, finishable_stream)); + auto stream = ClientStream{std::move(grpc_stream), std::move(data_stream)}; + auto finishable_stream = + 
std::make_shared>( + rpc, std::move(stream)); + *out = std::unique_ptr( + new StreamReader(std::move(rpc), nullptr, options.read_options, + options.stop_token, std::move(finishable_stream))); // Eagerly read the schema return static_cast(out->get())->EnsureDataStarted(); } @@ -1215,51 +1315,62 @@ class FlightClient::FlightClientImpl { const std::shared_ptr& schema, std::unique_ptr* out, std::unique_ptr* reader) { - using GrpcStream = grpc::ClientReaderWriter; - using StreamWriter = GrpcStreamWriter; + using ClientStream = internal::ClientReaderWriter; + using StreamWriter = ClientStreamWriter; auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr stream = stub_->DoPut(&rpc->context); + + ARROW_ASSIGN_OR_RAISE(auto data_stream, data_plane_->DoPut(&rpc->context)); + + std::shared_ptr> grpc_stream = + stub_->DoPut(&rpc->context); + auto stream = ClientStream{grpc_stream, std::move(data_stream)}; // The writer drains the reader on close to avoid hanging inside // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. 
- std::shared_ptr read_mutex = std::make_shared(); + auto read_mutex = std::make_shared(); auto finishable_stream = - std::make_shared>( - rpc, read_mutex, stream); - *reader = - std::unique_ptr(new GrpcMetadataReader(stream, read_mutex)); - return StreamWriter::Open(descriptor, schema, options.write_options, rpc, - write_size_limit_bytes_, finishable_stream, out); + std::make_shared>( + rpc, read_mutex, std::move(stream)); + *reader = std::unique_ptr( + new GrpcMetadataReader(grpc_stream, read_mutex)); + return StreamWriter::Open(descriptor, schema, options.write_options, std::move(rpc), + write_size_limit_bytes_, std::move(finishable_stream), out); } Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* writer, std::unique_ptr* reader) { - using GrpcStream = grpc::ClientReaderWriter; - using StreamReader = GrpcStreamReader; - using StreamWriter = GrpcStreamWriter; + using ClientStream = internal::ClientReaderWriter; + using StreamReader = ClientStreamReader; + using StreamWriter = ClientStreamWriter; auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr> stream = - stub_->DoExchange(&rpc->context); + + ARROW_ASSIGN_OR_RAISE(auto data_stream, data_plane_->DoExchange(&rpc->context)); + + std::shared_ptr> + grpc_stream = stub_->DoExchange(&rpc->context); + auto stream = ClientStream{std::move(grpc_stream), std::move(data_stream)}; // The writer drains the reader on close to avoid hanging inside // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. - std::shared_ptr read_mutex = std::make_shared(); + auto read_mutex = std::make_shared(); auto finishable_stream = - std::make_shared>( - rpc, read_mutex, stream); + std::make_shared>( + rpc, read_mutex, std::move(stream)); *reader = std::unique_ptr(new StreamReader( rpc, read_mutex, options.read_options, options.stop_token, finishable_stream)); // Do not eagerly read the schema. 
There may be metadata messages // before any data is sent, or data may not be sent at all. - return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc, - write_size_limit_bytes_, finishable_stream, writer); + return StreamWriter::Open(descriptor, nullptr, options.write_options, std::move(rpc), + write_size_limit_bytes_, std::move(finishable_stream), + writer); } private: std::unique_ptr stub_; + std::unique_ptr data_plane_; std::shared_ptr auth_handler_; #if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) // Scope the TlsServerAuthorizationCheckConfig to be at the class instance level, since diff --git a/cpp/src/arrow/flight/data_plane/internal.h b/cpp/src/arrow/flight/data_plane/internal.h new file mode 100644 index 00000000000..96117f8709a --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/internal.h @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include + +#include "arrow/result.h" + +namespace arrow { +namespace flight { +namespace internal { + +class ClientDataPlane; +class ServerDataPlane; + +enum class StreamType { kGet, kPut, kExchange }; + +struct DataPlaneMaker { + arrow::Result> (*make_client)(const std::string&); + arrow::Result> (*make_server)(const std::string&); +}; + +// data plane makers are defined in data plane drivers +DataPlaneMaker GetShmDataPlaneMaker(); +DataPlaneMaker GetUcxDataPlaneMaker(); + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/serialize.cc b/cpp/src/arrow/flight/data_plane/serialize.cc new file mode 100644 index 00000000000..2ee2b5cf608 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/serialize.cc @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/data_plane/serialize.h" +#include "arrow/flight/customize_protobuf.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/result.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#include +#include + +#include + +namespace arrow { +namespace flight { +namespace internal { + +namespace { + +void ReleaseBuffer(void* buffer_ptr) { + delete reinterpret_cast*>(buffer_ptr); +} + +} // namespace + +Status Deserialize(std::shared_ptr buffer, FlightData* data) { + // hold the buffer + std::shared_ptr* buffer_ptr = new std::shared_ptr(std::move(buffer)); + + const grpc::Slice slice((*buffer_ptr)->mutable_data(), + static_cast((*buffer_ptr)->size()), &ReleaseBuffer, + buffer_ptr); + grpc::ByteBuffer bbuf = grpc::ByteBuffer(&slice, 1); + + { + // make sure GrpcBuffer::Wrap goes the wanted path + auto grpc_bbuf = *reinterpret_cast(&bbuf); + DCHECK_EQ(grpc_bbuf->type, GRPC_BB_RAW); + DCHECK_EQ(grpc_bbuf->data.raw.compression, GRPC_COMPRESS_NONE); + DCHECK_EQ(grpc_bbuf->data.raw.slice_buffer.count, 1); + grpc_slice slice = grpc_bbuf->data.raw.slice_buffer.slices[0]; + DCHECK_NE(slice.refcount, 0); + } + + // buffer ownership is transferred to "data" on success + const Status st = FromGrpcStatus(FlightDataDeserialize(&bbuf, data)); + if (!st.ok()) { + delete buffer_ptr; + } + return st; +} + +SerializeSlice::SerializeSlice(grpc::Slice&& slice) { + slice_ = arrow::internal::make_unique(std::move(slice)); +} +SerializeSlice::SerializeSlice(SerializeSlice&&) = default; +SerializeSlice::~SerializeSlice() = default; + +const uint8_t* SerializeSlice::data() const { return slice_->begin(); } +int64_t SerializeSlice::size() const { return static_cast(slice_->size()); } + +arrow::Result> Serialize(const FlightPayload& payload, + int64_t* total_size) { + RETURN_NOT_OK(payload.Validate()); + + grpc::ByteBuffer bbuf; + bool owner; + RETURN_NOT_OK(FromGrpcStatus(FlightDataSerialize(payload, 
&bbuf, &owner))); + + if (total_size) { + *total_size = static_cast(bbuf.Length()); + } + + // ByteBuffer::Dump doesn't copy data buffer, IIUC + std::vector grpc_slices; + RETURN_NOT_OK(FromGrpcStatus(bbuf.Dump(&grpc_slices))); + + // move grpc slice life cycle to returned serialize slice + std::vector slices; + for (auto& grpc_slice : grpc_slices) { + SerializeSlice slice(std::move(grpc_slice)); + slices.emplace_back(std::move(slice)); + } + return slices; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/serialize.h b/cpp/src/arrow/flight/data_plane/serialize.h new file mode 100644 index 00000000000..9c9496bd8d9 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/serialize.h @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "arrow/buffer.h" +#include "arrow/result.h" + +#include +#include + +namespace grpc { + +class Slice; + +}; // namespace grpc + +namespace arrow { +namespace flight { + +struct FlightPayload; + +namespace internal { + +struct FlightData; + +// Reader buffer management +// # data plane receives data and creates/stores to one reader buffer +// # Deserialize() creates a new shared_ptr to hold the buffer, creates a +// grpc slice per that buffer, and installs destroyer callback, which +// deletes the created shared_ptr when grpc slice is freed +// # Deserialize() calls FlightDataDeserialize() which transfers grpc slice +// lifecycle to Buffer object (FlightData->body) managed by the end consumer +// * see GrpcBuffer::slice_ in serialization_internal.cc +// # releasing the reader buffer +// # end consumer frees FlightData +// # frees Buffer(GrpcBuffer) object +// # frees grpc slice GrpcBuffer::slice_ +// # invokes destroyer to release reader buffer + +// Reader buffer must be a contiguous memory block, see GrpcBuffer::Wrap +// - FlightDataDeserialize will copy and flatten non-contiguous blocks anyway +// - it's necessary to get destroyer called + +Status Deserialize(std::shared_ptr buffer, FlightData* data); + +// Writer buffer management +// # FlightDataSerialize() holds buffer (FlightPayload.ipc_msg.body_buffer[i]) +// in returned grpc bbuf +// * see SliceFromBuffer in serialization_internal.cc +// # dump grpc bbuf to a vector of grpc slice, then move to SerializeSlice[] +// # data plane sends data per returned SerializeSlice[] +// # releasing the writer buffer +// # data plane frees vector +// # frees grpc slice SerializeSlice::slice_ +// # release writer buffer + +// a simple wrapper of grpc::Slice, the only purpose is to hide grpc +// from data plane implementation +class SerializeSlice { + public: + explicit SerializeSlice(grpc::Slice&& slice); + SerializeSlice(SerializeSlice&&); + ~SerializeSlice(); + + const uint8_t* data() const; + 
int64_t size() const; + + private: + std::unique_ptr slice_; +}; + +arrow::Result> Serialize(const FlightPayload& payload, + int64_t* total_size = NULLPTR); + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/shm.cc b/cpp/src/arrow/flight/data_plane/shm.cc new file mode 100644 index 00000000000..9bbcbef28fe --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/shm.cc @@ -0,0 +1,826 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// shared memory data plane driver +// - client and server use fifo (named pipe) for messaging +// - client and server mmap one pre-allocated shared buffer for flight payload +// exchanging, with best performance +// - if no shared buffer available, a one-shot buffer is created on demand for +// each read/write operation, performance is probably poor + +// TODO(yibo): +// - implement back pressure (common for all data planes) +// - replace fifo with unix socket for ipc +// * current fifo based approach may leave garbage ipc files on crash +// with unix socket: unlink ipc files immediately after open, pass fd +// * fix a race issue (see ShmFifo::~ShmFifo()) +// * socket based messaging may be re-usable by other data planes +// - IS IT POSSIBLE to use existing grpc control path for data plane messaging +// - improve buffer cache management +// * finer de-/allocation, better resource usage, drop one-shot buffer +// * better to be re-usable for other data planes + +// XXX: performance depends heavily on whether the buffer cache is used effectively +// - if payload size > kBufferSize, cache cannot be used, performance suffers +// - cache capacity (kBufferCount) limits pending buffers not consumed by the +// reader, performance may suffer if reader is slower than the writer as the +// buffer cache is used up quickly + +// default buffer cache +static constexpr int kBufferCount = 4; +static constexpr int kBufferSize = 256 * 1024; + +#include "arrow/buffer.h" +#include "arrow/flight/data_plane/internal.h" +#include "arrow/flight/data_plane/serialize.h" +#include "arrow/flight/data_plane/types.h" +#include "arrow/result.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/counting_semaphore.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace arrow { +namespace 
flight { +namespace internal { + +namespace { + +// stream map keys transferred together with grpc client metadata +// must be unique, no capital letter +// - name prefix of fifo and shared memory +const char kIpcNamePrefix[] = "flight-dataplane-shm-ipc"; + +// message between client and server +struct ShmMsg { + static constexpr uint32_t kMagic = 0xFEEDCAFE; + static constexpr int kShmNameSize = 64; + // Get + // # server -> client: GetData + // # client -> server: GetAck/Nak + // # server -> client: GetDone (WritesDone) + // Put + // # client -> server: PutData + // # server -> client: PutAck/Nak + // # client -> server: PutDone (WritesDone) + enum Type { + Min, + GetData, + PutData, + GetAck, + PutAck, + GetNak, + PutNak, + GetDone, + PutDone, + GetInvalid, + PutInvalid, + Finish, + Max + }; + + ShmMsg() = default; + + ShmMsg(Type type, int64_t size, const char* shm_name) + : magic(kMagic), type(type), size(size) { + // shm_name strlen is verified in Make + const int n = snprintf(this->shm_name, kShmNameSize, "%s", shm_name); + DCHECK(n >= 0 && n < kShmNameSize); + } + + static arrow::Result Make(Type type, int64_t size = 0, + const std::string& shm_name = "") { + DCHECK(type > Type::Min && type < Type::Max); + if (shm_name.size() >= kShmNameSize) { + return Status::IOError("shared memory name length greater than ", + int(kShmNameSize)); + } + return ShmMsg(type, size, shm_name.c_str()); + } + + // passed across process, no pointer + uint32_t magic = kMagic; + Type type = Type::Min; + int64_t size = 0; + char shm_name[kShmNameSize]{}; +}; + +// shared memory buffer +class ShmBuffer : public MutableBuffer { + public: + // create and map a new shared memory with specified name and size, called by client + static arrow::Result> Create(const std::string& name, + int64_t size) { + DCHECK_GT(size, 0); + + int fd = shm_open(name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0666); + if (fd == -1) { + return Status::IOError("create shm: ", strerror(errno)); + } + if 
(ftruncate(fd, size) == -1) { + const int saved_errno = errno; + close(fd); + shm_unlink(name.c_str()); + return Status::IOError("ftruncate: ", strerror(saved_errno)); + } + + void* data = + mmap(NULL, static_cast(size), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + const int saved_errno = errno; + close(fd); + if (data == MAP_FAILED) { + shm_unlink(name.c_str()); + return Status::IOError("mmap: ", strerror(saved_errno)); + } + + return std::make_shared(reinterpret_cast(data), size); + } + + // open and map an existing shared memory, called by server + static arrow::Result> Open(const std::string& name, + int64_t size) { + DCHECK_GT(size, 0); + + int fd = shm_open(name.c_str(), O_RDWR, 0666); + if (fd == -1) { + return Status::IOError("open shm: ", strerror(errno)); + } + + void* data = + mmap(NULL, static_cast(size), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + const int saved_errno = errno; + // memory is mapped, we can close shm fd and delete rpc file + close(fd); + shm_unlink(name.c_str()); + if (data == MAP_FAILED) { + return Status::IOError("mmap: ", strerror(saved_errno)); + } + + return std::make_shared(reinterpret_cast(data), size); + } + + ShmBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {} + + ~ShmBuffer() override { munmap(mutable_data(), static_cast(size())); } +}; + +// per stream buffer cache +class ShmCache { + class CachedBuffer : public MutableBuffer { + public: + CachedBuffer(std::shared_ptr parent, atomic_uchar* refcnt) + : MutableBuffer(parent->mutable_data(), parent->size()), + parent_(std::move(parent)), + refcnt_(refcnt) {} + + // decreases reference count on destruction + ~CachedBuffer() { + const uint8_t current = atomic_fetch_sub(refcnt_, 1); + DCHECK(current == 1 || current == 2); + } + + private: + // holds the parent buffer as refcnt_ locates there + std::shared_ptr parent_; + atomic_uchar* refcnt_; + }; + + public: + static arrow::Result> Create(const std::string& name_prefix, + int buffer_count, + int buffer_size) 
{ + return CreateOrOpen(name_prefix, buffer_count, buffer_size, /*create=*/true); + } + + static arrow::Result> Open(const std::string& name_prefix, + int buffer_count, + int buffer_size) { + return CreateOrOpen(name_prefix, buffer_count, buffer_size, /*create=*/false); + } + + ShmCache(const std::string& name_prefix, int64_t buffer_count, int64_t buffer_size, + int64_t buffer_offset, std::shared_ptr buffer, atomic_uchar* refcnt) + : name_prefix_(name_prefix), + buffer_count_(buffer_count), + buffer_size_(buffer_size), + buffer_offset_(buffer_offset), + buffer_(std::move(buffer)), + refcnt_(refcnt) {} + + // called by writer + arrow::Result>> CreateBuffer( + int64_t size, bool client) { + if (size <= buffer_size_) { + // acquire free cached buffer + for (int i = 0; i < buffer_count_; ++i) { + uint8_t expected = 0; + if (atomic_compare_exchange_strong(&refcnt_[i], &expected, 1)) { + auto buffer = std::make_shared( + SliceMutableBuffer(buffer_, buffer_offset_ + buffer_size_ * i, size), + &refcnt_[i]); + // append buffer index to name + return std::make_pair(cached_buffer_prefix() + std::to_string(i), + std::move(buffer)); + } + } + } + // fallback to one-shot buffer if no cached buffer available + static std::atomic counter{0}; + const std::string name = + name_prefix_ + (client ? 
"c" : "s") + std::to_string(++counter); + ARROW_ASSIGN_OR_RAISE(auto buffer, ShmBuffer::Create(name, size)); + return std::make_pair(name, std::move(buffer)); + } + + // called by reader + arrow::Result> OpenBuffer(const std::string& shm_name, + int64_t size) { + const std::string name_prefix = cached_buffer_prefix(); + if (shm_name.find(name_prefix) == 0) { + const int i = std::stoi( + shm_name.substr(name_prefix.size(), shm_name.size() - name_prefix.size())); + if (i < 0 || i >= buffer_count_) { + return Status::IOError("invalid buffer index"); + } + // normally, writer is still holding the buffer, refcnt should be 1 + // but writer stream may have be destroyed which decreased refcnt to 0 + const uint8_t current = atomic_fetch_add(&refcnt_[i], 1); + DCHECK(current == 0 || current == 1); + return std::make_shared( + SliceMutableBuffer(buffer_, buffer_offset_ + buffer_size_ * i, size), + &refcnt_[i]); + } + return ShmBuffer::Open(shm_name, size); + } + + private: + static arrow::Result> CreateOrOpen( + const std::string& name_prefix, int64_t buffer_count, int64_t buffer_size, + bool create) { + if (buffer_count < 0 || buffer_size < 0) { + return Status::Invalid("invalid buffer count or size"); + } + if (buffer_count == 0 || buffer_size == 0) { + return arrow::internal::make_unique(name_prefix, 0, 0, 0, nullptr, + nullptr); + } + + // allocate all buffers at once, prepend refcnt[] array + buffer_size = bit_util::RoundUpToPowerOf2(buffer_size, 64); + const int64_t buffer_offset = bit_util::RoundUpToPowerOf2(buffer_count, 64); + const int64_t total_size = buffer_offset + buffer_size * buffer_count; + ARROW_ASSIGN_OR_RAISE(auto buffer, + create ? 
ShmBuffer::Create(name_prefix + "cc-shm", total_size) + : ShmBuffer::Open(name_prefix + "cc-shm", total_size)); + atomic_uchar* refcnt = reinterpret_cast(buffer->mutable_data()); + if (create) { + std::memset(refcnt, 0, buffer_count); + } + return arrow::internal::make_unique( + name_prefix, buffer_count, buffer_size, buffer_offset, std::move(buffer), refcnt); + } + + std::string cached_buffer_prefix() { return name_prefix_ + "cc-buf"; } + + const std::string name_prefix_; + // shared buffer + // - all cached buffers are allocate in one continuous buffer + // - prepends refcnt[] array with each element refers to one buffer + // * shared and accessed as atomic variables by client/server + // * acquiring steps: 0 - free, 1 - created by writer, 2 - opened by reader + // * releasing steps: both writer and reader decreased refcnt by 1 + // - memory layout + // +----------------+--------------+-----+------------------+ + // content | refcnt[count_] | buffer0 | ... | buffer[count_-1] | + // +----------------+--------------+-----+------------------+ + // size | buffer_count_ | buffer_size_ | ... 
| buffer_size_ | + // +----------------+--------------+-----+------------------+ + // ^ ^ + // | | + // 0 buffer_offset_ + const int64_t buffer_count_, buffer_size_, buffer_offset_; + std::shared_ptr buffer_; + atomic_uchar* refcnt_; + static_assert(sizeof(atomic_uchar) == 1, ""); +}; + +// fifo for client and server messages +class ShmFifo { + public: + // create and open two fifos for read/write, called by client + static arrow::Result> Create( + const std::string& fifo_name, std::unique_ptr&& shm_cache) { + const std::string reader_path = "/tmp/" + fifo_name + "s2c"; + const std::string writer_path = "/tmp/" + fifo_name + "c2s"; + if (mkfifo(reader_path.c_str(), 0666) == -1) { + return Status::IOError("mkfifo: ", strerror(errno)); + } + if (mkfifo(writer_path.c_str(), 0666) == -1) { + unlink(reader_path.c_str()); + return Status::IOError("mkfifo: ", strerror(errno)); + } + // fifo open hangs if RDONLY or WRONLY (until peer opens) + int reader_fd = open(reader_path.c_str(), O_RDWR); + int writer_fd = open(writer_path.c_str(), O_RDWR); + if (reader_fd == -1 || writer_fd == -1) { + unlink(reader_path.c_str()); + unlink(writer_path.c_str()); + if (reader_fd != -1) close(reader_fd); + if (writer_fd != -1) close(writer_fd); + return Status::IOError("create fifo"); + } + return arrow::internal::make_unique(std::move(shm_cache), reader_fd, + writer_fd); + } + + // open existing fifos, called by server + static arrow::Result> Open( + const std::string& fifo_name, std::unique_ptr&& shm_cache) { + const std::string reader_path = "/tmp/" + fifo_name + "c2s"; + const std::string writer_path = "/tmp/" + fifo_name + "s2c"; + int reader_fd = open(reader_path.c_str(), O_RDWR); + int writer_fd = open(writer_path.c_str(), O_RDWR); + // we've opened the fifos, unlink fifo files + unlink(reader_path.c_str()); + unlink(writer_path.c_str()); + if (reader_fd == -1 || writer_fd == -1) { + if (reader_fd != -1) close(reader_fd); + if (writer_fd != -1) close(writer_fd); + return 
Status::IOError("open fifos"); + } + return arrow::internal::make_unique(std::move(shm_cache), reader_fd, + writer_fd); + } + + ShmFifo(std::unique_ptr&& shm_cache, int reader_fd, int writer_fd) + : shm_cache_(std::move(shm_cache)), reader_fd_(reader_fd), writer_fd_(writer_fd) { + pipe(pipe_fds_); + reader_thread_ = std::thread([this] { ReaderThread(); }); + } + + ~ShmFifo() { + StopReaderThread(); + close(pipe_fds_[0]); + close(pipe_fds_[1]); + close(reader_fd_); + // XXX: if writer fifo closes immediately after writing finish message + // the reader fifo may not see the messag and timeout (both ends must + // be open for fifo to work properly) + // it won't happen after replacing fifo with unix socket + // below horrible code is a temporary workaround for this issue + int writer_fd = writer_fd_; + std::thread([writer_fd]() { + sleep(1); + close(writer_fd); + }).detach(); + } + + Status WriteData(ShmMsg::Type msg_type, std::shared_ptr buffer, + const std::string& shm_name) { + DCHECK(msg_type == ShmMsg::GetData || msg_type == ShmMsg::PutData); + if (peer_finished_) { + return Status::IOError("peer finished or cancelled"); + } + + const int64_t size = buffer->size(); + ARROW_ASSIGN_OR_RAISE(ShmMsg msg, ShmMsg::Make(msg_type, size, shm_name)); + + // hold write buffer in held_writers_ map to wait for peer response + // it must be done before writing os fifo, in case peer responses fast + { + const std::lock_guard lock(held_writers_mtx_); + DCHECK_EQ(held_writers_.find(shm_name), held_writers_.end()); + held_writers_[shm_name] = std::move(buffer); + } + + // write to os fifo + const Status st = OsWriteFifo(msg); + if (!st.ok()) { + // release the buffer on error + const std::lock_guard lock(held_writers_mtx_); + const auto n = held_writers_.erase(shm_name); + DCHECK_EQ(n, 1); + } + return st; + } + + Status WriteCtrl(ShmMsg::Type msg_type) { + ARROW_ASSIGN_OR_RAISE(ShmMsg msg, ShmMsg::Make(msg_type)); + return OsWriteFifo(msg); + } + + arrow::Result>> ReadMsg() { + 
RETURN_NOT_OK(reader_sem_.Acquire(1)); + const std::lock_guard lock(reader_queue_mtx_); + auto v = reader_queue_.front(); + reader_queue_.pop(); + return std::move(v); + } + + arrow::Result>> CreateBuffer( + int64_t size, bool client) { + return shm_cache_->CreateBuffer(size, client); + } + + private: + void ReaderThread() { + ShmMsg msg; + while (OsReadFifo(&msg).ok()) { + DCHECK_EQ(msg.magic, ShmMsg::kMagic); + + switch (msg.type) { + // data + case ShmMsg::GetData: + case ShmMsg::PutData: + // create buffer per received message, append to reader queue + { + ShmMsg::Type msg_type = msg.type; + std::shared_ptr buffer; + auto result = shm_cache_->OpenBuffer(msg.shm_name, msg.size); + if (result.ok()) { + buffer = result.ValueOrDie(); + } else { + msg_type = + msg.type == ShmMsg::GetData ? ShmMsg::GetInvalid : ShmMsg::PutInvalid; + } + { + std::lock_guard lock(reader_queue_mtx_); + reader_queue_.emplace(msg_type, std::move(buffer)); + } + ARROW_UNUSED(reader_sem_.Release(1)); + // send response so the writer can free its buffer + if (msg.type == ShmMsg::GetData) { + msg.type = result.ok() ? ShmMsg::GetAck : ShmMsg::GetNak; + } else { + msg.type = result.ok() ? 
ShmMsg::PutAck : ShmMsg::PutNak; + } + DCHECK_OK(OsWriteFifo(msg)); + } + break; + // writes done, error, finish + case ShmMsg::GetDone: + case ShmMsg::PutDone: + case ShmMsg::GetInvalid: + case ShmMsg::PutInvalid: + case ShmMsg::Finish: { + const std::lock_guard lock(reader_queue_mtx_); + reader_queue_.emplace(msg.type, std::shared_ptr()); + } + ARROW_UNUSED(reader_sem_.Release(1)); + if (msg.type == ShmMsg::Finish) { + peer_finished_ = true; + } + break; + // data response + case ShmMsg::GetAck: + case ShmMsg::GetNak: + case ShmMsg::PutAck: + case ShmMsg::PutNak: + // release according write buffer + { + const std::lock_guard lock(held_writers_mtx_); + const auto n = held_writers_.erase(msg.shm_name); + DCHECK_EQ(n, 1); + } + break; + default: + DCHECK(false); + break; + } + + std::memset(&msg, 0, sizeof(ShmMsg)); + } + } + + // make sure to write/read a full message per call + Status OsWriteFifo(const ShmMsg& msg) { + if (fifo_write_error_) { + return Status::IOError("fifo write error"); + } + const uint8_t* buf = reinterpret_cast(&msg); + size_t count = sizeof(ShmMsg); + while (count > 0) { + ssize_t ret = write(writer_fd_, buf, count); + if (ret == -1 && errno != EINTR) { + fifo_write_error_ = true; + return Status::IOError("write: ", strerror(errno)); + } + count -= ret; + buf += ret; + } + return Status::OK(); + } + + Status OsReadFifo(ShmMsg* msg) { + struct pollfd fds[2]; + fds[0].fd = reader_fd_; + fds[0].events = POLLIN; + fds[1].fd = pipe_fds_[0]; + fds[1].events = POLLIN; + + uint8_t* buf = reinterpret_cast(msg); + size_t count = sizeof(ShmMsg); + while (count > 0) { + // force checking stop token every 5 seconds, in case pipe method fails + if (poll(fds, 2, 5000) == -1) { + if (errno == EINTR) { + continue; + } + return Status::IOError("poll: ", strerror(errno)); + } + if ((fds[1].revents & POLLIN) || stop_) { + // exit thread if pipe received something or stop token is set + return Status::IOError("stop requested"); + } + if (fds[0].revents & 
(POLLERR | POLLHUP | POLLNVAL)) { + return Status::IOError("error polled"); + } + if (fds[0].revents & POLLIN) { + ssize_t ret = read(reader_fd_, buf, count); + count -= ret; + buf += ret; + } + } + return Status::OK(); + } + + void StopReaderThread() { + stop_ = true; + while (write(pipe_fds_[1], "s", 1) == -1 && errno == EINTR) { + } + reader_thread_.join(); + } + + std::unique_ptr shm_cache_; + // os fifo fd + int reader_fd_, writer_fd_; + bool fifo_write_error_{false}; + std::thread reader_thread_; + // self pipe to stop reader thread + int pipe_fds_[2]; + // force stoping reader thread in case pipe method fails + std::atomic stop_{false}; + // read message queue with timeout (XXX: suitable value?) + arrow::util::CountingSemaphore reader_sem_{/*initial=*/0, /*timeout_seconds=*/5}; + std::queue>> reader_queue_; + std::mutex reader_queue_mtx_; + // write buffers cannot be freed before peer response + std::unordered_map> held_writers_; + std::mutex held_writers_mtx_; + // received finish or cancel message, cannot write anymore + std::atomic peer_finished_{false}; +}; + +struct ShmStreamImpl { + ShmStreamImpl(bool client, StreamType stream_type, std::unique_ptr&& shm_fifo) + : client_(client), stream_type_(stream_type), shm_fifo_(std::move(shm_fifo)) {} + + Status Read(FlightData* data) { + DCHECK_NE(stream_type_, client_ ? StreamType::kPut : StreamType::kGet); + const std::string prefix = client_ ? 
"client: " : "server: "; + if (reads_done_) { + return Status::IOError(prefix, "reads done"); + } + ShmMsg::Type msg_type; + std::shared_ptr buffer; + ARROW_ASSIGN_OR_RAISE(std::tie(msg_type, buffer), shm_fifo_->ReadMsg()); + switch (msg_type) { + case ShmMsg::GetData: + case ShmMsg::PutData: + DCHECK_EQ(msg_type == ShmMsg::GetData, client_); + return Deserialize(std::move(buffer), data); + case ShmMsg::GetDone: + case ShmMsg::PutDone: + DCHECK_EQ(msg_type == ShmMsg::GetDone, client_); + reads_done_ = true; + return Status::IOError(prefix, "peer done writing"); + case ShmMsg::GetInvalid: + case ShmMsg::PutInvalid: + DCHECK_EQ(msg_type == ShmMsg::GetInvalid, client_); + return Status::Invalid(prefix, "recevied invalid payload"); + case ShmMsg::Finish: + DCHECK(!client_); + reads_done_ = writes_done_ = true; + return Status::IOError(prefix, "client finished"); + default: + DCHECK(false); + return Status::Invalid(prefix, "received invalid message"); + } + } + + Status Write(const FlightPayload& payload) { + DCHECK_NE(stream_type_, client_ ? StreamType::kGet : StreamType::kPut); + if (writes_done_) { + return Status::IOError(client_ ? "client" : "server", ": writes done"); + } + int64_t total_size; + auto result = Serialize(payload, &total_size); + if (!result.ok()) { + ARROW_UNUSED( + shm_fifo_->WriteCtrl(client_ ? ShmMsg::PutInvalid : ShmMsg::GetInvalid)); + return result.status(); + } + DCHECK_GT(total_size, 0); + const std::vector& slices = result.ValueOrDie(); + + std::string shm_name; + std::shared_ptr buffer; + ARROW_ASSIGN_OR_RAISE(std::tie(shm_name, buffer), + shm_fifo_->CreateBuffer(total_size, client_)); + CopySlicesToBuffer(slices, buffer.get()); + return shm_fifo_->WriteData(client_ ? ShmMsg::PutData : ShmMsg::GetData, + std::move(buffer), shm_name); + } + + Status WritesDone() { + DCHECK_NE(stream_type_, client_ ? 
StreamType::kGet : StreamType::kPut); + if (writes_done_) { + return Status::OK(); + } + writes_done_ = true; + return shm_fifo_->WriteCtrl(client_ ? ShmMsg::PutDone : ShmMsg::GetDone); + } + + static void CopySlicesToBuffer(const std::vector& slices, + Buffer* buffer) { + uint8_t* dest_ptr = buffer->mutable_data(); + for (const auto& slice : slices) { + DCHECK_LE(dest_ptr + slice.size(), buffer->data() + buffer->size()); + std::memcpy(dest_ptr, slice.data(), static_cast(slice.size())); + dest_ptr += slice.size(); + } + DCHECK_EQ(dest_ptr, buffer->data() + buffer->size()); + } + + const bool client_; + const StreamType stream_type_; + std::unique_ptr shm_fifo_; + std::atomic reads_done_{false}, writes_done_{false}; +}; + +class ShmClientStream : public DataClientStream { + public: + ShmClientStream(StreamType stream_type, std::unique_ptr&& shm_fifo) + : stream_(/*client=*/true, stream_type, std::move(shm_fifo)) {} + + ~ShmClientStream() { ARROW_UNUSED(Finish()); } + + Status Read(FlightData* data) override { return stream_.Read(data); } + Status Write(const FlightPayload& payload) override { return stream_.Write(payload); } + Status WritesDone() override { return stream_.WritesDone(); } + + Status Finish() override { + stream_.writes_done_ = stream_.reads_done_ = true; + return stream_.shm_fifo_->WriteCtrl(ShmMsg::Finish); + } + + void TryCancel() override { + ARROW_UNUSED(stream_.shm_fifo_->WriteCtrl(ShmMsg::Finish)); + } + + private: + ShmStreamImpl stream_; +}; + +class ShmServerStream : public DataServerStream { + public: + ShmServerStream(StreamType stream_type, std::unique_ptr&& shm_fifo) + : stream_(/*client=*/false, stream_type, std::move(shm_fifo)) {} + + ~ShmServerStream() { + if (stream_.stream_type_ != StreamType::kPut) { + ARROW_UNUSED(WritesDone()); + } + } + + Status Read(FlightData* data) override { return stream_.Read(data); } + Status Write(const FlightPayload& payload) override { return stream_.Write(payload); } + Status WritesDone() override { 
return stream_.WritesDone(); } + + private: + ShmStreamImpl stream_; +}; + +class ShmClientDataPlane : public ClientDataPlane { + private: + ResultClientStream DoGetImpl(StreamMap* map) override { + return Do(map, StreamType::kGet); + } + + ResultClientStream DoPutImpl(StreamMap* map) override { + return Do(map, StreamType::kPut); + } + + ResultClientStream DoExchangeImpl(StreamMap* map) override { + return Do(map, StreamType::kExchange); + } + + ResultClientStream Do(StreamMap* map, StreamType stream_type) { + const std::string name_prefix = GenerateNamePrefix(); + (*map)[kIpcNamePrefix] = name_prefix; + ARROW_ASSIGN_OR_RAISE(auto shm_cache, + ShmCache::Create(name_prefix, kBufferCount, kBufferSize)); + ARROW_ASSIGN_OR_RAISE(auto shm_fifo, + ShmFifo::Create(name_prefix, std::move(shm_cache))); + return arrow::internal::make_unique(stream_type, + std::move(shm_fifo)); + } + + // generate system unique name prefix for ipc objects + // prefix = "/flight-shm-{client pid}-{stream counter}-" + std::string GenerateNamePrefix() { + static std::atomic counter{0}; + std::stringstream name_prefix; + name_prefix << "/flight-shm-" << getpid() << '-' << ++counter << '-'; + return name_prefix.str(); + } +}; + +class ShmServerDataPlane : public ServerDataPlane { + private: + ResultServerStream DoGetImpl(const StreamMap& map) override { + return Do(map, StreamType::kGet); + } + + ResultServerStream DoPutImpl(const StreamMap& map) override { + return Do(map, StreamType::kPut); + } + + ResultServerStream DoExchangeImpl(const StreamMap& map) override { + return Do(map, StreamType::kExchange); + } + + std::vector stream_keys() override { return {kIpcNamePrefix}; } + + ResultServerStream Do(const StreamMap& map, StreamType stream_type) { + ARROW_ASSIGN_OR_RAISE(auto name_prefix, GetNamePrefix(map)); + ARROW_ASSIGN_OR_RAISE(auto shm_cache, + ShmCache::Open(name_prefix, kBufferCount, kBufferSize)); + ARROW_ASSIGN_OR_RAISE(auto shm_fifo, + ShmFifo::Open(name_prefix, 
std::move(shm_cache))); + return arrow::internal::make_unique(stream_type, + std::move(shm_fifo)); + } + + // extract ipc objecs name prefix from stream map set by client + arrow::Result GetNamePrefix(const StreamMap& map) { + auto it = map.find(kIpcNamePrefix); + if (it == map.end()) { + return Status::Invalid("key not found: ", kIpcNamePrefix); + } + return it->second; + } +}; + +arrow::Result> MakeShmClientDataPlane( + const std::string&) { + return arrow::internal::make_unique(); +} + +arrow::Result> MakeShmServerDataPlane( + const std::string&) { + return arrow::internal::make_unique(); +} + +} // namespace + +DataPlaneMaker GetShmDataPlaneMaker() { + return {MakeShmClientDataPlane, MakeShmServerDataPlane}; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/types.cc b/cpp/src/arrow/flight/data_plane/types.cc new file mode 100644 index 00000000000..7733bb1feb6 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/types.cc @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/data_plane/types.h" +#include "arrow/flight/data_plane/internal.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#ifdef GRPCPP_PP_INCLUDE +#include +#else +#include +#endif + +namespace arrow { +namespace flight { +namespace internal { + +namespace { + +// data plane registry (name -> data plane maker) +struct Registry { + std::map makers; + + // register all data planes on creation of registry singleton + Registry() { +#ifdef FLIGHT_DP_SHM + Register("shm", GetShmDataPlaneMaker()); +#endif + } + + static const Registry& instance() { + static const Registry registry; + return registry; + } + + void Register(const std::string& name, const DataPlaneMaker& maker) { + DCHECK_EQ(makers.find(name), makers.end()); + makers[name] = maker; + } + + arrow::Result GetDataPlaneMaker(const std::string& uri) const { + const std::string name = uri.substr(0, uri.find(':')); + auto it = makers.find(name); + if (it == makers.end()) { + return Status::Invalid("unknown data plane: ", name); + } + return it->second; + } +}; + +std::string GetGrpcMetadata(const grpc::ServerContext& context, const std::string& key) { + const auto client_metadata = context.client_metadata(); + const auto found = client_metadata.find(key); + std::string token; + if (found == client_metadata.end()) { + DCHECK(false); + token = ""; + } else { + token = std::string(found->second.data(), found->second.length()); + } + return token; +} + +// TODO(yibo): getting data plane uri from env var is bad, shall we extend +// location to support two uri (control, data)? or any better approach to +// negotiate data plane uri? 
+std::string DataUriFromLocation(const Location& /*location*/) { + auto result = arrow::internal::GetEnvVar("FLIGHT_DATAPLANE"); + if (result.ok()) { + return result.ValueOrDie(); + } + return ""; // empty uri -> default grpc data plane +} + +} // namespace + +arrow::Result> ClientDataPlane::Make( + const Location& location) { + const std::string uri = DataUriFromLocation(location); + if (uri.empty()) return arrow::internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(auto maker, Registry::instance().GetDataPlaneMaker(uri)); + return maker.make_client(uri); +} + +arrow::Result> ServerDataPlane::Make( + const Location& location) { + const std::string uri = DataUriFromLocation(location); + if (uri.empty()) return arrow::internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(auto maker, Registry::instance().GetDataPlaneMaker(uri)); + return maker.make_server(uri); +} + +void ClientDataPlane::AppendStreamMap( + grpc::ClientContext* context, const std::map& stream_map) { + for (const auto& kv : stream_map) { + context->AddMetadata(kv.first, kv.second); + } +} + +std::map ServerDataPlane::ExtractStreamMap( + const grpc::ServerContext& context) { + std::map stream_map; + for (const auto& key : stream_keys()) { + stream_map[key] = GetGrpcMetadata(context, key); + } + return stream_map; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/types.h b/cpp/src/arrow/flight/data_plane/types.h new file mode 100644 index 00000000000..aa2ced776c2 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/types.h @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/result.h" + +namespace grpc { + +class ClientContext; +class ServerContext; + +}; // namespace grpc + +namespace arrow { +namespace flight { + +struct FlightPayload; +struct Location; + +namespace internal { + +struct FlightData; + +// Data plane stream simulates grpc::Client|ServerReader[Writer] +// - DataClientStream is created before DataServerStream +// - all interfaces should return IOError on data plane related errors +// - WritesDone closes the writer side, the operation should be idempotent and +// after that, the writer should return IOError on "Write", and the reader +// should return IOError on "Read" (pending reads on peer are not discarded) +// - Finish closes the client, no more read/write can be performed, it should +// be idempotent and after that, both the client and server should return +// IOError on "Read" or "Write" (pending reads on server are not discarded) +// - TryCancel tells server to stop writing, should be idempotent +// - data race +// * a single data plane may create several client/server stream pairs which +// run in parallel +// * for a single stream, Read may run in parallel with Write (DoExchange), +// Read/Read and Write/Write normally run in sequence (DoGet/DoPut), but +// should behave correctly if multiple Reads or Writes run in parallel + +// NOTE: grpc defines separate classes for Reader/Writer/ReaderWriter, +// data plane implements all the read/write interfaces in one class + +struct DataClientStream { + virtual ~DataClientStream()
= default; + + virtual Status Read(FlightData* data) = 0; + virtual Status Write(const FlightPayload& payload) = 0; + virtual Status WritesDone() = 0; + virtual Status Finish() = 0; + virtual void TryCancel() = 0; +}; + +struct DataServerStream { + virtual ~DataServerStream() = default; + + virtual Status Read(FlightData* data) = 0; + virtual Status Write(const FlightPayload& payload) = 0; + // grpc doesn't implement server writes done + virtual Status WritesDone() = 0; +}; + +// Data plane is initialized at client/server startup, it creates data streams +// to replace grpc streams for flight payload transmission (get/put/exchange). + +class ClientDataPlane { + public: + // client can send a {str:str} map to server together with grpc metadata + // this is useful to match client data stream with server data stream + using StreamMap = std::map; + using ResultClientStream = arrow::Result>; + + virtual ~ClientDataPlane() = default; + + static arrow::Result> Make(const Location& location); + + ResultClientStream DoGet(grpc::ClientContext* context) { + StreamMap stream_map; + ARROW_ASSIGN_OR_RAISE(auto data_stream, DoGetImpl(&stream_map)); + AppendStreamMap(context, stream_map); + return data_stream; + } + + ResultClientStream DoPut(grpc::ClientContext* context) { + StreamMap stream_map; + ARROW_ASSIGN_OR_RAISE(auto data_stream, DoPutImpl(&stream_map)); + AppendStreamMap(context, stream_map); + return data_stream; + } + + ResultClientStream DoExchange(grpc::ClientContext* context) { + StreamMap stream_map; + ARROW_ASSIGN_OR_RAISE(auto data_stream, DoExchangeImpl(&stream_map)); + AppendStreamMap(context, stream_map); + return data_stream; + } + + private: + // implement empty data plane in base class + virtual ResultClientStream DoGetImpl(StreamMap* stream_map) { return NULLPTR; } + virtual ResultClientStream DoPutImpl(StreamMap* stream_map) { return NULLPTR; } + virtual ResultClientStream DoExchangeImpl(StreamMap* stream_map) { return NULLPTR; } + + void 
AppendStreamMap(grpc::ClientContext* context, const StreamMap& stream_map); +}; + +class ServerDataPlane { + public: + using StreamMap = std::map; + using ResultServerStream = arrow::Result>; + + virtual ~ServerDataPlane() = default; + + static arrow::Result> Make(const Location& location); + + ResultServerStream DoGet(const grpc::ServerContext& context) { + return DoGetImpl(ExtractStreamMap(context)); + } + + ResultServerStream DoPut(const grpc::ServerContext& context) { + return DoPutImpl(ExtractStreamMap(context)); + } + + ResultServerStream DoExchange(const grpc::ServerContext& context) { + return DoExchangeImpl(ExtractStreamMap(context)); + } + + private: + virtual ResultServerStream DoGetImpl(const StreamMap& stream_map) { return NULLPTR; } + virtual ResultServerStream DoPutImpl(const StreamMap& stream_map) { return NULLPTR; } + virtual ResultServerStream DoExchangeImpl(const StreamMap& stream_map) { + return NULLPTR; + } + virtual std::vector stream_keys() { return {}; } + + StreamMap ExtractStreamMap(const grpc::ServerContext& context); +}; + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index 7efd034ad25..a8c9b0c812d 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -185,7 +185,7 @@ class FlightPerfServer : public FlightServerBase { std::unique_ptr* data_stream) override { perf::Token token; CHECK_PARSE(token.ParseFromString(request.ticket)); - return GetPerfBatches(token, perf_schema_, false, data_stream); + return GetPerfBatches(token, perf_schema_, /*use_verifier=*/false, data_stream); } Status DoPut(const ServerCallContext& context, diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index 5f7d0cc487c..de4e5469fa4 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -87,11 +87,11 @@ 
bool ReadPayload(grpc::ClientReaderWriter* reader // The Flight reader can then peek at the message to determine whether // it has application metadata or not, and pass the message to // RecordBatchStreamReader as appropriate. -template +template class PeekableFlightDataReader { public: - explicit PeekableFlightDataReader(ReaderPtr stream) - : stream_(stream), peek_(), finished_(false), valid_(false) {} + explicit PeekableFlightDataReader(Reader stream) + : stream_(std::move(stream)), peek_(), finished_(false), valid_(false) {} void Peek(internal::FlightData** out) { *out = nullptr; @@ -132,7 +132,7 @@ class PeekableFlightDataReader { return valid_; } - if (!internal::ReadPayload(&*stream_, &peek_)) { + if (!stream_.Read(&peek_).ok()) { finished_ = true; valid_ = false; } else { @@ -141,7 +141,7 @@ class PeekableFlightDataReader { return valid_; } - ReaderPtr stream_; + Reader stream_; internal::FlightData peek_; bool finished_; bool valid_; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 988bd690d51..517dcff1ecb 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -62,11 +62,7 @@ #include "arrow/flight/server_middleware.h" #include "arrow/flight/types.h" -using FlightService = arrow::flight::protocol::FlightService; -using ServerContext = grpc::ServerContext; - -template -using ServerWriter = grpc::ServerWriter; +#include "arrow/flight/data_plane/types.h" namespace arrow { namespace flight { @@ -96,6 +92,67 @@ namespace pb = arrow::flight::protocol; } \ } while (false) +namespace internal { + +template +Status ReadServer(GrpcServerStream* grpc_stream, DataServerStream* data_stream, + FlightData* payload) { + if (data_stream) { + return data_stream->Read(payload); + } else { + const bool ok = internal::ReadPayload(grpc_stream, payload); + return ok ? 
Status::OK() : Status::IOError("gRPC server read failed"); + } +} + +template +Status WriteServer(GrpcServerStream* grpc_stream, DataServerStream* data_stream, + const FlightPayload& payload) { + if (data_stream) { + return data_stream->Write(payload); + } else { + return internal::WritePayload(payload, grpc_stream); + } +} + +template +Status WritesDoneServer(GrpcServerStream*, DataServerStream* data_stream) { + if (data_stream) { + RETURN_NOT_OK(data_stream->WritesDone()); + } + // grpc server doesn't implement WritesDone + return Status::OK(); +} + +struct ServerWriter { + using GrpcServerStream = grpc::ServerWriter; + GrpcServerStream* grpc_stream; + DataServerStream* data_stream; + + Status Write(const FlightPayload& payload) { + return WriteServer(grpc_stream, data_stream, payload); + } + Status WritesDone() { return WritesDoneServer(grpc_stream, data_stream); } +}; + +// WriteT can be pb::FlightData or pb::PutResult +template +struct ServerReaderWriter { + using GrpcServerStream = grpc::ServerReaderWriter; + GrpcServerStream* grpc_stream; + DataServerStream* data_stream; + + Status Read(FlightData* payload) { + return ReadServer(grpc_stream, data_stream, payload); + } + Status Write(const FlightPayload& payload) { + return WriteServer(grpc_stream, data_stream, payload); + } + Status WritesDone() { return WritesDoneServer(grpc_stream, data_stream); } +}; + +} // namespace internal + namespace { // A MessageReader implementation that reads from a gRPC ServerReader. 
@@ -104,7 +161,7 @@ template class FlightIpcMessageReader : public ipc::MessageReader { public: explicit FlightIpcMessageReader( - std::shared_ptr> peekable_reader, + std::shared_ptr> peekable_reader, std::shared_ptr* app_metadata) : peekable_reader_(peekable_reader), app_metadata_(app_metadata) {} @@ -127,7 +184,7 @@ class FlightIpcMessageReader : public ipc::MessageReader { } protected: - std::shared_ptr> peekable_reader_; + std::shared_ptr> peekable_reader_; // A reference to FlightMessageReaderImpl.app_metadata_. That class // can't access the app metadata because when it Peek()s the stream, // it may be looking at a dictionary batch, not the record @@ -138,14 +195,13 @@ class FlightIpcMessageReader : public ipc::MessageReader { bool stream_finished_ = false; }; -template +template class FlightMessageReaderImpl : public FlightMessageReader { public: - using GrpcStream = grpc::ServerReaderWriter; + using ServerStream = internal::ServerReaderWriter; - explicit FlightMessageReaderImpl(GrpcStream* reader) - : reader_(reader), - peekable_reader_(new internal::PeekableFlightDataReader(reader)) {} + explicit FlightMessageReaderImpl(ServerStream reader) + : peekable_reader_(new internal::PeekableFlightDataReader(reader)) {} Status Init() { // Peek the first message to get the descriptor. 
@@ -209,7 +265,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { return Status::IOError("Client never sent a data message"); } auto message_reader = std::unique_ptr( - new FlightIpcMessageReader(peekable_reader_, &app_metadata_)); + new FlightIpcMessageReader(peekable_reader_, &app_metadata_)); ARROW_ASSIGN_OR_RAISE( batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader))); } @@ -217,8 +273,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { } FlightDescriptor descriptor_; - GrpcStream* reader_; - std::shared_ptr> peekable_reader_; + std::shared_ptr> peekable_reader_; std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; @@ -284,8 +339,9 @@ class GrpcServerAuthSender : public ServerAuthSender { /// stream for DoExchange. class DoExchangeMessageWriter : public FlightMessageWriter { public: - explicit DoExchangeMessageWriter( - grpc::ServerReaderWriter* stream) + using ServerStream = internal::ServerReaderWriter; + + explicit DoExchangeMessageWriter(ServerStream stream) : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {} Status Begin(const std::shared_ptr& schema, @@ -328,15 +384,16 @@ class DoExchangeMessageWriter : public FlightMessageWriter { } Status Close() override { - // It's fine to Close() without writing data - return Status::OK(); + // It's fine to Close() without writing data for grpc + // Data plane needs to tell the reader that we've done writing + return stream_.WritesDone(); } ipc::WriteStats stats() const override { return stats_; } private: Status WritePayload(const FlightPayload& payload) { - RETURN_NOT_OK(internal::WritePayload(payload, stream_)); + RETURN_NOT_OK(stream_.Write(payload)); ++stats_.num_messages; return Status::OK(); } @@ -365,7 +422,7 @@ class DoExchangeMessageWriter : public FlightMessageWriter { return Status::OK(); } - grpc::ServerReaderWriter* stream_; + ServerStream stream_; ::arrow::ipc::IpcWriteOptions ipc_options_; 
ipc::DictionaryFieldMapper mapper_; ipc::WriteStats stats_; @@ -410,7 +467,7 @@ class GrpcServerCallContext : public ServerCallContext { private: friend class FlightServiceImpl; - ServerContext* context_; + grpc::ServerContext* context_; std::string peer_; std::string peer_identity_; std::vector> middleware_; @@ -432,7 +489,7 @@ class GrpcAddCallHeaders : public AddCallHeaders { // This class glues an implementation of FlightServerBase together with the // gRPC service definition, so the latter is not exposed in the public API -class FlightServiceImpl : public FlightService::Service { +class FlightServiceImpl : public protocol::FlightService::Service { public: explicit FlightServiceImpl( std::shared_ptr auth_handler, @@ -442,7 +499,7 @@ class FlightServiceImpl : public FlightService::Service { : auth_handler_(auth_handler), middleware_(middleware), server_(server) {} template - grpc::Status WriteStream(Iterator* iterator, ServerWriter* writer) { + grpc::Status WriteStream(Iterator* iterator, grpc::ServerWriter* writer) { if (!iterator) { return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate"); } @@ -467,7 +524,7 @@ class FlightServiceImpl : public FlightService::Service { template grpc::Status WriteStream(const std::vector& values, - ServerWriter* writer) { + grpc::ServerWriter* writer) { // Write flight info to stream until listing is exhausted for (const UserType& value : values) { ProtoType pb_value; @@ -482,7 +539,7 @@ class FlightServiceImpl : public FlightService::Service { } // Authenticate the client (if applicable) and construct the call context - grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, + grpc::Status CheckAuth(const FlightMethod& method, grpc::ServerContext* context, GrpcServerCallContext& flight_context) { if (!auth_handler_) { const auto auth_context = context->auth_context(); @@ -496,14 +553,7 @@ class FlightServiceImpl : public FlightService::Service { flight_context.peer_identity_ = ""; } } else { - 
const auto client_metadata = context->client_metadata(); - const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader); - std::string token; - if (auth_header == client_metadata.end()) { - token = ""; - } else { - token = std::string(auth_header->second.data(), auth_header->second.length()); - } + const std::string token = GetMetadata(context, internal::kGrpcAuthHeader); GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_)); } @@ -511,7 +561,7 @@ class FlightServiceImpl : public FlightService::Service { } // Authenticate the client (if applicable) and construct the call context - grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, + grpc::Status MakeCallContext(const FlightMethod& method, grpc::ServerContext* context, GrpcServerCallContext& flight_context) { // Run server middleware const CallInfo info{method}; @@ -542,7 +592,7 @@ class FlightServiceImpl : public FlightService::Service { } grpc::Status Handshake( - ServerContext* context, + grpc::ServerContext* context, grpc::ServerReaderWriter* stream) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( @@ -561,8 +611,8 @@ class FlightServiceImpl : public FlightService::Service { auth_handler_->Authenticate(&outgoing, &incoming)); } - grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request, - ServerWriter* writer) { + grpc::Status ListFlights(grpc::ServerContext* context, const pb::Criteria* request, + grpc::ServerWriter* writer) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::ListFlights, context, flight_context)); @@ -584,7 +634,8 @@ class FlightServiceImpl : public FlightService::Service { WriteStream(listing.get(), writer)); } - grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request, + grpc::Status GetFlightInfo(grpc::ServerContext* context, + const pb::FlightDescriptor* request, pb::FlightInfo* response) { 
GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( @@ -609,7 +660,8 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); } - grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, + grpc::Status GetSchema(grpc::ServerContext* context, + const pb::FlightDescriptor* request, pb::SchemaResult* response) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context)); @@ -633,9 +685,12 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); } - grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, - ServerWriter* writer) { + grpc::Status DoGet(grpc::ServerContext* context, const pb::Ticket* request, + grpc::ServerWriter* grpc_writer) { GrpcServerCallContext flight_context(context); + std::unique_ptr data_writer; + SERVICE_RETURN_NOT_OK(flight_context, + server_->data_plane()->DoGet(*context).Value(&data_writer)); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null"); @@ -655,7 +710,8 @@ class FlightServiceImpl : public FlightService::Service { // Write the schema as the first message in the stream FlightPayload schema_payload; SERVICE_RETURN_NOT_OK(flight_context, data_stream->GetSchemaPayload(&schema_payload)); - auto status = internal::WritePayload(schema_payload, writer); + auto writer = internal::ServerWriter{grpc_writer, data_writer.get()}; + auto status = writer.Write(schema_payload); if (status.IsIOError()) { // gRPC doesn't give any way for us to know why the message // could not be written. 
@@ -668,8 +724,11 @@ class FlightServiceImpl : public FlightService::Service { FlightPayload payload; SERVICE_RETURN_NOT_OK(flight_context, data_stream->Next(&payload)); // End of stream - if (payload.ipc_message.metadata == nullptr) break; - auto status = internal::WritePayload(payload, writer); + if (payload.ipc_message.metadata == nullptr) { + SERVICE_RETURN_NOT_OK(flight_context, writer.WritesDone()); + break; + } + auto status = writer.Write(payload); // Connection terminated if (status.IsIOError()) break; SERVICE_RETURN_NOT_OK(flight_context, status); @@ -677,26 +736,38 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); } - grpc::Status DoPut(ServerContext* context, - grpc::ServerReaderWriter* reader) { + grpc::Status DoPut( + grpc::ServerContext* context, + grpc::ServerReaderWriter* grpc_reader) { GrpcServerCallContext flight_context(context); + std::unique_ptr data_reader; + SERVICE_RETURN_NOT_OK(flight_context, + server_->data_plane()->DoPut(*context).Value(&data_reader)); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); + auto reader = + internal::ServerReaderWriter{grpc_reader, data_reader.get()}; auto message_reader = std::unique_ptr>( new FlightMessageReaderImpl(reader)); SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); auto metadata_writer = - std::unique_ptr(new GrpcMetadataWriter(reader)); + std::unique_ptr(new GrpcMetadataWriter(grpc_reader)); RETURN_WITH_MIDDLEWARE(flight_context, server_->DoPut(flight_context, std::move(message_reader), std::move(metadata_writer))); } grpc::Status DoExchange( - ServerContext* context, - grpc::ServerReaderWriter* stream) { + grpc::ServerContext* context, + grpc::ServerReaderWriter* grpc_stream) { GrpcServerCallContext flight_context(context); + std::unique_ptr data_stream; + SERVICE_RETURN_NOT_OK( + flight_context, server_->data_plane()->DoExchange(*context).Value(&data_stream)); 
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context)); + + auto stream = + internal::ServerReaderWriter{grpc_stream, data_stream.get()}; auto message_reader = std::unique_ptr>( new FlightMessageReaderImpl(stream)); SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); @@ -707,8 +778,8 @@ class FlightServiceImpl : public FlightService::Service { std::move(writer))); } - grpc::Status ListActions(ServerContext* context, const pb::Empty* request, - ServerWriter* writer) { + grpc::Status ListActions(grpc::ServerContext* context, const pb::Empty* request, + grpc::ServerWriter* writer) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::ListActions, context, flight_context)); @@ -718,8 +789,8 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, WriteStream(types, writer)); } - grpc::Status DoAction(ServerContext* context, const pb::Action* request, - ServerWriter* writer) { + grpc::Status DoAction(grpc::ServerContext* context, const pb::Action* request, + grpc::ServerWriter* writer) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null"); @@ -752,6 +823,18 @@ class FlightServiceImpl : public FlightService::Service { } private: + std::string GetMetadata(const grpc::ServerContext* context, const char* key) { + const auto client_metadata = context->client_metadata(); + const auto auth_header = client_metadata.find(key); + std::string token; + if (auth_header == client_metadata.end()) { + token = ""; + } else { + token = std::string(auth_header->second.data(), auth_header->second.length()); + } + return token; + } + std::shared_ptr auth_handler_; std::vector>> middleware_; @@ -844,6 +927,7 @@ class ServerSignalHandler { struct FlightServerBase::Impl { std::unique_ptr service_; std::unique_ptr 
server_; + std::unique_ptr data_plane_; int port_; // Signal handlers (on Windows) and the shutdown handler (other platforms) @@ -909,6 +993,9 @@ FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } FlightServerBase::~FlightServerBase() {} Status FlightServerBase::Init(const FlightServerOptions& options) { + ARROW_ASSIGN_OR_RAISE(impl_->data_plane_, + internal::ServerDataPlane::Make(options.location)); + impl_->service_.reset( new FlightServiceImpl(options.auth_handler, options.middleware, this)); @@ -964,11 +1051,16 @@ Status FlightServerBase::Init(const FlightServerOptions& options) { if (!impl_->server_) { return Status::UnknownError("Server did not start properly"); } + return Status::OK(); } int FlightServerBase::port() const { return impl_->port_; } +internal::ServerDataPlane* FlightServerBase::data_plane() const { + return impl_->data_plane_.get(); +} + Status FlightServerBase::SetShutdownOnSignals(const std::vector sigs) { impl_->signals_ = sigs; impl_->old_signal_handlers_.clear(); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 218b640adc8..a38b3f90529 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -43,6 +43,10 @@ namespace flight { class ServerMiddleware; class ServerMiddlewareFactory; +namespace internal { +class ServerDataPlane; +} // namespace internal + /// \brief Interface that produces a sequence of IPC payloads to be sent in /// FlightData protobuf messages class ARROW_FLIGHT_EXPORT FlightDataStream { @@ -180,6 +184,10 @@ class ARROW_FLIGHT_EXPORT FlightServerBase { /// domain socket). int port() const; + /// \brief Get the data plane of the Flight server. + /// This method must only be called after Init(). + internal::ServerDataPlane* data_plane() const; + /// \brief Set the server to stop when receiving any of the given signal /// numbers. /// This method must be called before Serve().