diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 0a43ec18f60..fcfc4f3f8ae 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -234,6 +234,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_FLIGHT_SQL "Build the Arrow Flight SQL extension" OFF) + define_option(ARROW_FLIGHT_DP_SHM "Build the Arrow Flight shared memory data plane" OFF) + define_option(ARROW_GANDIVA "Build the Gandiva libraries" OFF) define_option(ARROW_GCS diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 2cf8c9913e5..5dcbf32631c 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -151,6 +151,13 @@ endif() # Restore the CXXFLAGS that were modified above set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") +# Data plane source files +# TODO(yibo): separate common (serialize.cc) and driver (shm.cc) +if(ARROW_FLIGHT_DP_SHM) + add_definitions(-DFLIGHT_DP_SHM) + set(DATAPLANE_SRCS data_plane/serialize.cc data_plane/shm.cc) +endif() + # Note, we do not compile the generated Protobuf sources directly, instead # compiling then via protocol_internal.cc which contains some gRPC template # overrides to enable Flight-specific optimizations. See comments in @@ -164,7 +171,8 @@ set(ARROW_FLIGHT_SRCS serialization_internal.cc server.cc server_auth.cc - types.cc) + types.cc + data_plane/types.cc) add_arrow_lib(arrow_flight CMAKE_PACKAGE_NAME @@ -175,6 +183,7 @@ add_arrow_lib(arrow_flight ARROW_FLIGHT_LIBRARIES SOURCES ${ARROW_FLIGHT_SRCS} + ${DATAPLANE_SRCS} PRECOMPILED_HEADERS "$<$:arrow/flight/pch.h>" DEPENDENCIES diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index f9728f849ad..fed06bd5fd3 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -59,6 +59,8 @@ #include "arrow/flight/serialization_internal.h" #include "arrow/flight/types.h" +#include "arrow/flight/data_plane/types.h" + namespace arrow { namespace flight { @@ -127,6 +129,93 @@ struct ClientRpc { } }; +namespace internal { + +template +Status FinishClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream, + ClientRpc* rpc) { + Status st1, st2; + if (data_stream) { + st1 = data_stream->Finish(); + } + st2 = internal::FromGrpcStatus(grpc_stream->Finish(), &rpc->context); + return st2.ok() ? st1 : st2; +} + +void CancelClient(DataClientStream* data_stream, ClientRpc* rpc) { + rpc->context.TryCancel(); + if (data_stream) { + data_stream->TryCancel(); + } +} + +template +Status ReadClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream, + FlightData* payload) { + if (data_stream) { + return data_stream->Read(payload); + } else { + const bool ok = internal::ReadPayload(grpc_stream, payload); + return ok ? Status::OK() : Status::IOError("gRPC client read failed"); + } +} + +template +Status WriteClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream, + const FlightPayload& payload) { + if (data_stream) { + return data_stream->Write(payload); + } else { + return internal::WritePayload(payload, grpc_stream); + } +} + +template +Status WritesDoneClient(GrpcClientStream* grpc_stream, DataClientStream* data_stream) { + const bool ok = grpc_stream->WritesDone(); + Status st; + if (data_stream) { + st = data_stream->WritesDone(); + } + return ok ? 
st : Status::IOError("gRPC WritesDone failed"); +} + +struct ClientReader { + using GrpcClientStream = grpc::ClientReader; + std::shared_ptr grpc_stream; + std::shared_ptr data_stream; + + Status Read(FlightData* payload) { + return ReadClient(grpc_stream.get(), data_stream.get(), payload); + } + Status Finish(ClientRpc* rpc) { + return FinishClient(grpc_stream.get(), data_stream.get(), rpc); + } + void Cancel(ClientRpc* rpc) { CancelClient(data_stream.get(), rpc); } +}; + +// ReadT can be pb::FlightData or pb::PutResult +template +struct ClientReaderWriter { + using GrpcClientStream = grpc::ClientReaderWriter; + std::shared_ptr grpc_stream; + std::shared_ptr data_stream; + + Status Read(FlightData* payload) { + return ReadClient(grpc_stream.get(), data_stream.get(), payload); + } + Status Write(const FlightPayload& payload) { + return WriteClient(grpc_stream.get(), data_stream.get(), payload); + } + Status WritesDone() { return WritesDoneClient(grpc_stream.get(), data_stream.get()); } + Status Finish(ClientRpc* rpc) { + return FinishClient(grpc_stream.get(), data_stream.get(), rpc); + } + void Cancel(ClientRpc* rpc) { CancelClient(data_stream.get(), rpc); } +}; + +} // namespace internal + /// Helper that manages Finish() of a gRPC stream. /// /// When we encounter an error (e.g. could not decode an IPC message), @@ -139,15 +228,21 @@ struct ClientRpc { /// /// The template lets us abstract between DoGet/DoExchange and DoPut, /// which respectively read internal::FlightData and pb::PutResult. +// +// Stream contains two shared_ptrs to grpc stream and possibly an +// additional data stream. See internal::ClientReader for details. template class FinishableStream { public: - FinishableStream(std::shared_ptr rpc, std::shared_ptr stream) - : rpc_(rpc), stream_(stream), finished_(false), server_status_() {} + FinishableStream(std::shared_ptr rpc, Stream stream) + : rpc_(std::move(rpc)), + stream_(std::move(stream)), + finished_(false), + server_status_() {} virtual ~FinishableStream() = default; /// \brief Get the underlying stream. - std::shared_ptr stream() const { return stream_; } + Stream stream() const { return stream_; } /// \brief Finish the call, adding server context to the given status. virtual Status Finish(Status st) { @@ -162,11 +257,11 @@ class FinishableStream { // indicate that it is done writing, but not done reading, it // should use DoneWriting. 
ReadT message; - while (internal::ReadPayload(stream_.get(), &message)) { + while (internal::ReadPayload(stream_.grpc_stream.get(), &message)) { // Drain the read side to avoid gRPC hanging in Finish() } - server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context); + server_status_ = stream_.Finish(rpc_.get()); finished_ = true; return MergeStatus(std::move(st)); @@ -184,7 +279,7 @@ class FinishableStream { } std::shared_ptr rpc_; - std::shared_ptr stream_; + Stream stream_; bool finished_; Status server_status_; }; @@ -197,9 +292,8 @@ template class FinishableWritableStream : public FinishableStream { public: FinishableWritableStream(std::shared_ptr rpc, - std::shared_ptr read_mutex, - std::shared_ptr stream) - : FinishableStream(rpc, stream), + std::shared_ptr read_mutex, Stream stream) + : FinishableStream(std::move(rpc), std::move(stream)), finish_mutex_(), read_mutex_(read_mutex), done_writing_(false) {} @@ -213,7 +307,7 @@ class FinishableWritableStream : public FinishableStream { return Status::OK(); } done_writing_ = true; - if (!this->stream()->WritesDone()) { + if (!this->stream().WritesDone().ok()) { // Error happened, try to close the stream to get more detailed info return Finish(MakeFlightError(FlightStatusCode::Internal, "Could not flush pending record batches")); @@ -239,7 +333,7 @@ class FinishableWritableStream : public FinishableStream { // Try to flush pending writes. Don't use our WritesDone() to // avoid recursion. - bool finished_writes = done_writing_ || this->stream()->WritesDone(); + bool finished_writes = done_writing_ || this->stream().WritesDone().ok(); done_writing_ = true; st = FinishableStream::Finish(std::move(st)); @@ -436,8 +530,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader { GrpcIpcMessageReader( std::shared_ptr rpc, std::shared_ptr read_mutex, std::shared_ptr> stream, - std::shared_ptr>> - peekable_reader, + std::shared_ptr> peekable_reader, std::shared_ptr* app_metadata) : rpc_(rpc), read_mutex_(read_mutex), @@ -476,8 +569,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader { // side calls Finish(). Nullable as DoGet doesn't need this. std::shared_ptr read_mutex_; std::shared_ptr> stream_; - std::shared_ptr>> - peekable_reader_; + std::shared_ptr> peekable_reader_; // A reference to GrpcStreamReader.app_metadata_. 
That class // can't access the app metadata because when it Peek()s the stream, // it may be looking at a dictionary batch, not the record @@ -490,18 +582,19 @@ class GrpcIpcMessageReader : public ipc::MessageReader { /// The implementation of the public-facing API for reading from a /// FlightData stream template -class GrpcStreamReader : public FlightStreamReader { +class ClientStreamReader : public FlightStreamReader { public: - GrpcStreamReader(std::shared_ptr rpc, std::shared_ptr read_mutex, - const ipc::IpcReadOptions& options, StopToken stop_token, - std::shared_ptr> stream) + ClientStreamReader( + std::shared_ptr rpc, std::shared_ptr read_mutex, + const ipc::IpcReadOptions& options, StopToken stop_token, + std::shared_ptr> stream) : rpc_(rpc), read_mutex_(read_mutex), options_(options), stop_token_(std::move(stop_token)), stream_(stream), - peekable_reader_(new internal::PeekableFlightDataReader>( - stream->stream())), + peekable_reader_( + new internal::PeekableFlightDataReader(stream->stream())), app_metadata_(nullptr) {} Status EnsureDataStarted() { @@ -584,7 +677,7 @@ class GrpcStreamReader : public FlightStreamReader { return ReadAll(table, stop_token_); } using FlightStreamReader::ReadAll; - void Cancel() override { rpc_->context.TryCancel(); } + void Cancel() override { stream_->stream().Cancel(rpc_.get()); } private: std::unique_lock TakeGuard() { @@ -608,8 +701,7 @@ class GrpcStreamReader : public FlightStreamReader { ipc::IpcReadOptions options_; StopToken stop_token_; std::shared_ptr> stream_; - std::shared_ptr>> - peekable_reader_; + std::shared_ptr> peekable_reader_; std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; @@ -628,16 +720,16 @@ template class DoPutPayloadWriter; template -class GrpcStreamWriter : public FlightStreamWriter { +class ClientStreamWriter : public FlightStreamWriter { public: - ~GrpcStreamWriter() override = default; + ~ClientStreamWriter() override = default; - using GrpcStream = grpc::ClientReaderWriter; + using ClientStream = internal::ClientReaderWriter; - explicit GrpcStreamWriter( + explicit ClientStreamWriter( const FlightDescriptor& descriptor, std::shared_ptr rpc, int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options, - std::shared_ptr> writer) + std::shared_ptr> writer) : app_metadata_(nullptr), batch_writer_(nullptr), writer_(std::move(writer)), @@ -651,7 +743,7 @@ class GrpcStreamWriter : public FlightStreamWriter { const FlightDescriptor& descriptor, std::shared_ptr schema, const ipc::IpcWriteOptions& options, std::shared_ptr rpc, int64_t write_size_limit_bytes, - std::shared_ptr> writer, + std::shared_ptr> writer, std::unique_ptr* out); Status CheckStarted() { @@ -688,7 +780,7 @@ class GrpcStreamWriter : public FlightStreamWriter { Status WriteMetadata(std::shared_ptr app_metadata) override { FlightPayload payload{}; payload.app_metadata = app_metadata; - auto status = internal::WritePayload(payload, writer_->stream().get()); + auto status = writer_->stream().Write(payload); if (status.IsIOError()) { return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, "Could not write metadata to stream")); @@ -741,7 +833,7 @@ class GrpcStreamWriter : public FlightStreamWriter { friend class DoPutPayloadWriter; std::shared_ptr app_metadata_; std::unique_ptr batch_writer_; - std::shared_ptr> writer_; + std::shared_ptr> writer_; // Fields used to lazy-initialize the IpcPayloadWriter. They're // invalid once Begin() is called. 
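A condensed, illustrative sketch of the dispatch rule that the internal::ClientReader/ClientReaderWriter wrappers earlier in this file implement (simplified template parameters; the patch's own helpers are the templated ReadClient/WriteClient/WritesDoneClient above, these names are hypothetical): when a data plane stream was negotiated, payloads travel over it and gRPC stays the control path; otherwise the existing gRPC payload path is used.

// Sketch only: mirrors internal::ReadClient/WriteClient above.
template <typename GrpcStream>
Status ReadOrFallback(GrpcStream* grpc_stream, internal::DataClientStream* data_stream,
                      internal::FlightData* out) {
  if (data_stream) return data_stream->Read(out);           // data plane path
  const bool ok = internal::ReadPayload(grpc_stream, out);  // gRPC fallback
  return ok ? Status::OK() : Status::IOError("gRPC client read failed");
}
template <typename GrpcStream>
Status WriteOrFallback(GrpcStream* grpc_stream, internal::DataClientStream* data_stream,
                       const FlightPayload& payload) {
  if (data_stream) return data_stream->Write(payload);      // data plane path
  return internal::WritePayload(payload, grpc_stream);      // gRPC fallback
}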
@@ -757,13 +849,13 @@ class GrpcStreamWriter : public FlightStreamWriter { template class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { public: - using GrpcStream = grpc::ClientReaderWriter; + using ClientStream = internal::ClientReaderWriter; DoPutPayloadWriter( const FlightDescriptor& descriptor, std::shared_ptr rpc, int64_t write_size_limit_bytes, - std::shared_ptr> writer, - GrpcStreamWriter* stream_writer) + std::shared_ptr> writer, + ClientStreamWriter* stream_writer) : descriptor_(descriptor), rpc_(rpc), write_size_limit_bytes_(write_size_limit_bytes), @@ -809,7 +901,7 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { } } - auto status = internal::WritePayload(payload, writer_->stream().get()); + auto status = writer_->stream().Write(payload); if (status.IsIOError()) { return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, "Could not write record batch to stream")); @@ -826,21 +918,21 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { const FlightDescriptor descriptor_; std::shared_ptr rpc_; int64_t write_size_limit_bytes_; - std::shared_ptr> writer_; + std::shared_ptr> writer_; bool first_payload_; - GrpcStreamWriter* stream_writer_; + ClientStreamWriter* stream_writer_; }; template -Status GrpcStreamWriter::Open( +Status ClientStreamWriter::Open( const FlightDescriptor& descriptor, std::shared_ptr schema, // this schema is nullable const ipc::IpcWriteOptions& options, std::shared_ptr rpc, int64_t write_size_limit_bytes, - std::shared_ptr> writer, + std::shared_ptr> writer, std::unique_ptr* out) { - std::unique_ptr> instance( - new GrpcStreamWriter( + std::unique_ptr> instance( + new ClientStreamWriter( descriptor, std::move(rpc), write_size_limit_bytes, options, writer)); if (schema) { // The schema was provided (DoPut). Eagerly write the schema and @@ -852,7 +944,7 @@ Status GrpcStreamWriter::Open( // calls Begin() to send data, we'll send a redundant descriptor. 
FlightPayload payload{}; RETURN_NOT_OK(internal::ToPayload(descriptor, &payload.descriptor)); - auto status = internal::WritePayload(payload, instance->writer_->stream().get()); + auto status = instance->writer_->stream().Write(payload); if (status.IsIOError()) { return writer->Finish(MakeFlightError(FlightStatusCode::Internal, "Could not write descriptor to stream")); @@ -917,6 +1009,8 @@ constexpr char kDummyRootCert[] = class FlightClient::FlightClientImpl { public: Status Connect(const Location& location, const FlightClientOptions& options) { + ARROW_ASSIGN_OR_RAISE(data_plane_, internal::ClientDataPlane::Make(location)); + const std::string& scheme = location.scheme(); std::stringstream grpc_uri; @@ -1194,19 +1288,25 @@ class FlightClient::FlightClientImpl { Status DoGet(const FlightCallOptions& options, const Ticket& ticket, std::unique_ptr* out) { - using StreamReader = GrpcStreamReader>; + using ClientStream = internal::ClientReader; + using StreamReader = ClientStreamReader; pb::Ticket pb_ticket; internal::ToProto(ticket, &pb_ticket); auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr> stream = + + ARROW_ASSIGN_OR_RAISE(auto data_stream, data_plane_->DoGet(&rpc->context)); + + std::shared_ptr> grpc_stream = stub_->DoGet(&rpc->context, pb_ticket); - auto finishable_stream = std::make_shared< - FinishableStream, internal::FlightData>>( - rpc, stream); - *out = std::unique_ptr(new StreamReader( - rpc, nullptr, options.read_options, options.stop_token, finishable_stream)); + auto stream = ClientStream{std::move(grpc_stream), std::move(data_stream)}; + auto finishable_stream = + std::make_shared>( + rpc, std::move(stream)); + *out = std::unique_ptr( + new StreamReader(std::move(rpc), nullptr, options.read_options, + options.stop_token, std::move(finishable_stream))); // Eagerly read the schema return static_cast(out->get())->EnsureDataStarted(); } @@ -1215,51 +1315,62 @@ class FlightClient::FlightClientImpl { const std::shared_ptr& schema, std::unique_ptr* out, std::unique_ptr* reader) { - using GrpcStream = grpc::ClientReaderWriter; - using StreamWriter = GrpcStreamWriter; + using ClientStream = internal::ClientReaderWriter; + using StreamWriter = ClientStreamWriter; auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr stream = stub_->DoPut(&rpc->context); + + ARROW_ASSIGN_OR_RAISE(auto data_stream, data_plane_->DoPut(&rpc->context)); + + std::shared_ptr> grpc_stream = + stub_->DoPut(&rpc->context); + auto stream = ClientStream{grpc_stream, std::move(data_stream)}; // The writer drains the reader on close to avoid hanging inside // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. 
- std::shared_ptr read_mutex = std::make_shared(); + auto read_mutex = std::make_shared(); auto finishable_stream = - std::make_shared>( - rpc, read_mutex, stream); - *reader = - std::unique_ptr(new GrpcMetadataReader(stream, read_mutex)); - return StreamWriter::Open(descriptor, schema, options.write_options, rpc, - write_size_limit_bytes_, finishable_stream, out); + std::make_shared>( + rpc, read_mutex, std::move(stream)); + *reader = std::unique_ptr( + new GrpcMetadataReader(grpc_stream, read_mutex)); + return StreamWriter::Open(descriptor, schema, options.write_options, std::move(rpc), + write_size_limit_bytes_, std::move(finishable_stream), out); } Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* writer, std::unique_ptr* reader) { - using GrpcStream = grpc::ClientReaderWriter; - using StreamReader = GrpcStreamReader; - using StreamWriter = GrpcStreamWriter; + using ClientStream = internal::ClientReaderWriter; + using StreamReader = ClientStreamReader; + using StreamWriter = ClientStreamWriter; auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr> stream = - stub_->DoExchange(&rpc->context); + + ARROW_ASSIGN_OR_RAISE(auto data_stream, data_plane_->DoExchange(&rpc->context)); + + std::shared_ptr> + grpc_stream = stub_->DoExchange(&rpc->context); + auto stream = ClientStream{std::move(grpc_stream), std::move(data_stream)}; // The writer drains the reader on close to avoid hanging inside // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. - std::shared_ptr read_mutex = std::make_shared(); + auto read_mutex = std::make_shared(); auto finishable_stream = - std::make_shared>( - rpc, read_mutex, stream); + std::make_shared>( + rpc, read_mutex, std::move(stream)); *reader = std::unique_ptr(new StreamReader( rpc, read_mutex, options.read_options, options.stop_token, finishable_stream)); // Do not eagerly read the schema. There may be metadata messages // before any data is sent, or data may not be sent at all. - return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc, - write_size_limit_bytes_, finishable_stream, writer); + return StreamWriter::Open(descriptor, nullptr, options.write_options, std::move(rpc), + write_size_limit_bytes_, std::move(finishable_stream), + writer); } private: std::unique_ptr stub_; + std::unique_ptr data_plane_; std::shared_ptr auth_handler_; #if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) // Scope the TlsServerAuthorizationCheckConfig to be at the class instance level, since diff --git a/cpp/src/arrow/flight/data_plane/internal.h b/cpp/src/arrow/flight/data_plane/internal.h new file mode 100644 index 00000000000..96117f8709a --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/internal.h @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/result.h" + +namespace arrow { +namespace flight { +namespace internal { + +class ClientDataPlane; +class ServerDataPlane; + +enum class StreamType { kGet, kPut, kExchange }; + +struct DataPlaneMaker { + arrow::Result> (*make_client)(const std::string&); + arrow::Result> (*make_server)(const std::string&); +}; + +// data plane makers are defined in data plane drivers +DataPlaneMaker GetShmDataPlaneMaker(); +DataPlaneMaker GetUcxDataPlaneMaker(); + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/serialize.cc b/cpp/src/arrow/flight/data_plane/serialize.cc new file mode 100644 index 00000000000..2ee2b5cf608 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/serialize.cc @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/data_plane/serialize.h" +#include "arrow/flight/customize_protobuf.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/result.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#include +#include + +#include + +namespace arrow { +namespace flight { +namespace internal { + +namespace { + +void ReleaseBuffer(void* buffer_ptr) { + delete reinterpret_cast*>(buffer_ptr); +} + +} // namespace + +Status Deserialize(std::shared_ptr buffer, FlightData* data) { + // hold the buffer + std::shared_ptr* buffer_ptr = new std::shared_ptr(std::move(buffer)); + + const grpc::Slice slice((*buffer_ptr)->mutable_data(), + static_cast((*buffer_ptr)->size()), &ReleaseBuffer, + buffer_ptr); + grpc::ByteBuffer bbuf = grpc::ByteBuffer(&slice, 1); + + { + // make sure GrpcBuffer::Wrap goes the wanted path + auto grpc_bbuf = *reinterpret_cast(&bbuf); + DCHECK_EQ(grpc_bbuf->type, GRPC_BB_RAW); + DCHECK_EQ(grpc_bbuf->data.raw.compression, GRPC_COMPRESS_NONE); + DCHECK_EQ(grpc_bbuf->data.raw.slice_buffer.count, 1); + grpc_slice slice = grpc_bbuf->data.raw.slice_buffer.slices[0]; + DCHECK_NE(slice.refcount, 0); + } + + // buffer ownership is transferred to "data" on success + const Status st = FromGrpcStatus(FlightDataDeserialize(&bbuf, data)); + if (!st.ok()) { + delete buffer_ptr; + } + return st; +} + +SerializeSlice::SerializeSlice(grpc::Slice&& slice) { + slice_ = arrow::internal::make_unique(std::move(slice)); +} +SerializeSlice::SerializeSlice(SerializeSlice&&) = default; +SerializeSlice::~SerializeSlice() = default; + +const uint8_t* SerializeSlice::data() const { return slice_->begin(); } +int64_t SerializeSlice::size() const { return static_cast(slice_->size()); } + +arrow::Result> Serialize(const 
FlightPayload& payload, + int64_t* total_size) { + RETURN_NOT_OK(payload.Validate()); + + grpc::ByteBuffer bbuf; + bool owner; + RETURN_NOT_OK(FromGrpcStatus(FlightDataSerialize(payload, &bbuf, &owner))); + + if (total_size) { + *total_size = static_cast(bbuf.Length()); + } + + // ByteBuffer::Dump doesn't copy data buffer, IIUC + std::vector grpc_slices; + RETURN_NOT_OK(FromGrpcStatus(bbuf.Dump(&grpc_slices))); + + // move grpc slice life cycle to returned serialize slice + std::vector slices; + for (auto& grpc_slice : grpc_slices) { + SerializeSlice slice(std::move(grpc_slice)); + slices.emplace_back(std::move(slice)); + } + return slices; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/serialize.h b/cpp/src/arrow/flight/data_plane/serialize.h new file mode 100644 index 00000000000..9c9496bd8d9 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/serialize.h @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/buffer.h" +#include "arrow/result.h" + +#include +#include + +namespace grpc { + +class Slice; + +}; // namespace grpc + +namespace arrow { +namespace flight { + +struct FlightPayload; + +namespace internal { + +struct FlightData; + +// Reader buffer management +// # data plane receives data and creates/stores to one reader buffer +// # Deserialize() creates a new shared_ptr to hold the buffer, creates a +// gprc slice per that buffer, and installs destroyer callback, which +// deletes the created shared_ptr when grpc slice is freed +// # Deserialize() calls FlightDataDeserialize() which transfers grpc slice +// lifecyce to Buffer object (FlightData->body) managed by the end consumer +// * see GrpcBuffer::slice_ in serialization_internal.cc +// # releasing the reader buffer +// # end consumer frees FlightData +// # frees Buffer(GrpcBuffer) object +// # frees grpc slice GrpcBuffer::slice_ +// # invokes destroyer to release reader buffer + +// Reader buffer must be a continuous memory block, see GrpcBuffer::Wrap +// - FlightDataDeserialize will copy and flatten non-continuous blocks anyway +// - it's necessary to get destroyer called + +Status Deserialize(std::shared_ptr buffer, FlightData* data); + +// Writer buffer management +// # FlightDataSerialize() holds buffer (FlightPayload.ipc_msg.body_buffer[i]) +// in returned grpc bbuf +// * see SliceFromBuffer in serialization_internal.cc +// # dump grpc bbuf to a vector of grpc slice, then move to SerializeSlice[] +// # data plane sends data per returned SerializeSlice[] +// # releasing the writer buffer +// # data plane frees vector +// # frees grpc slice SerializeSlice::slice_ +// # release writer buffer + +// a simple wrapper of grpc::Slice, the only purpose is to hide 
grpc
+// from the data plane implementation
+class SerializeSlice {
+ public:
+  explicit SerializeSlice(grpc::Slice&& slice);
+  SerializeSlice(SerializeSlice&&);
+  ~SerializeSlice();
+
+  const uint8_t* data() const;
+  int64_t size() const;
+
+ private:
+  std::unique_ptr slice_;
+};
+
+arrow::Result> Serialize(const FlightPayload& payload,
+                                int64_t* total_size = NULLPTR);
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/data_plane/shm.cc b/cpp/src/arrow/flight/data_plane/shm.cc
new file mode 100644
index 00000000000..9bbcbef28fe
--- /dev/null
+++ b/cpp/src/arrow/flight/data_plane/shm.cc
@@ -0,0 +1,826 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// shared memory data plane driver
+// - client and server use fifo (named pipe) for messaging
+// - client and server mmap one pre-allocated shared buffer for exchanging
+//   flight payloads, which gives the best performance
+// - if no shared buffer is available, a one-shot buffer is created on demand
+//   for each read/write operation; performance is probably poor
+
+// TODO(yibo):
+// - implement back pressure (common for all data planes)
+// - replace fifo with unix socket for ipc
+//   * the current fifo based approach may leave garbage ipc files on crash;
+//     with a unix socket: unlink ipc files immediately after open, pass the fd
+//   * fix a race issue (see ShmFifo::~ShmFifo())
+//   * socket based messaging may be re-usable by other data planes
+// - IS IT POSSIBLE to use existing grpc control path for data plane messaging
+// - improve buffer cache management
+//   * finer de-/allocation, better resource usage, drop one-shot buffer
+//   * better to be re-usable for other data planes
+
+// XXX: performance depends heavily on whether the buffer cache is used effectively
+// - if payload size > kBufferSize, the cache cannot be used and performance suffers
+// - cache capacity (kBufferCount) limits the pending buffers not yet consumed by
+//   the reader; performance may suffer if the reader is slower than the writer,
+//   as the buffer cache is used up quickly
+
+// default buffer cache
+static constexpr int kBufferCount = 4;
+static constexpr int kBufferSize = 256 * 1024;
+
+#include "arrow/buffer.h"
+#include "arrow/flight/data_plane/internal.h"
+#include "arrow/flight/data_plane/serialize.h"
+#include "arrow/flight/data_plane/types.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/counting_semaphore.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+namespace {
+
+// stream map keys transferred
together with grpc client metadata +// must be unique, no capital letter +// - name prefix of fifo and shared memory +const char kIpcNamePrefix[] = "flight-dataplane-shm-ipc"; + +// message between client and server +struct ShmMsg { + static constexpr uint32_t kMagic = 0xFEEDCAFE; + static constexpr int kShmNameSize = 64; + // Get + // # server -> client: GetData + // # client -> server: GetAck/Nak + // # server -> client: GetDone (WritesDone) + // Put + // # client -> server: PutData + // # server -> client: PutAck/Nak + // # client -> server: PutDone (WritesDone) + enum Type { + Min, + GetData, + PutData, + GetAck, + PutAck, + GetNak, + PutNak, + GetDone, + PutDone, + GetInvalid, + PutInvalid, + Finish, + Max + }; + + ShmMsg() = default; + + ShmMsg(Type type, int64_t size, const char* shm_name) + : magic(kMagic), type(type), size(size) { + // shm_name strlen is verified in Make + const int n = snprintf(this->shm_name, kShmNameSize, "%s", shm_name); + DCHECK(n >= 0 && n < kShmNameSize); + } + + static arrow::Result Make(Type type, int64_t size = 0, + const std::string& shm_name = "") { + DCHECK(type > Type::Min && type < Type::Max); + if (shm_name.size() >= kShmNameSize) { + return Status::IOError("shared memory name length greater than ", + int(kShmNameSize)); + } + return ShmMsg(type, size, shm_name.c_str()); + } + + // passed across process, no pointer + uint32_t magic = kMagic; + Type type = Type::Min; + int64_t size = 0; + char shm_name[kShmNameSize]{}; +}; + +// shared memory buffer +class ShmBuffer : public MutableBuffer { + public: + // create and map a new shared memory with specified name and size, called by client + static arrow::Result> Create(const std::string& name, + int64_t size) { + DCHECK_GT(size, 0); + + int fd = shm_open(name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0666); + if (fd == -1) { + return Status::IOError("create shm: ", strerror(errno)); + } + if (ftruncate(fd, size) == -1) { + const int saved_errno = errno; + close(fd); + shm_unlink(name.c_str()); + return Status::IOError("ftruncate: ", strerror(saved_errno)); + } + + void* data = + mmap(NULL, static_cast(size), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + const int saved_errno = errno; + close(fd); + if (data == MAP_FAILED) { + shm_unlink(name.c_str()); + return Status::IOError("mmap: ", strerror(saved_errno)); + } + + return std::make_shared(reinterpret_cast(data), size); + } + + // open and map an existing shared memory, called by server + static arrow::Result> Open(const std::string& name, + int64_t size) { + DCHECK_GT(size, 0); + + int fd = shm_open(name.c_str(), O_RDWR, 0666); + if (fd == -1) { + return Status::IOError("open shm: ", strerror(errno)); + } + + void* data = + mmap(NULL, static_cast(size), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + const int saved_errno = errno; + // memory is mapped, we can close shm fd and delete rpc file + close(fd); + shm_unlink(name.c_str()); + if (data == MAP_FAILED) { + return Status::IOError("mmap: ", strerror(saved_errno)); + } + + return std::make_shared(reinterpret_cast(data), size); + } + + ShmBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {} + + ~ShmBuffer() override { munmap(mutable_data(), static_cast(size())); } +}; + +// per stream buffer cache +class ShmCache { + class CachedBuffer : public MutableBuffer { + public: + CachedBuffer(std::shared_ptr parent, atomic_uchar* refcnt) + : MutableBuffer(parent->mutable_data(), parent->size()), + parent_(std::move(parent)), + refcnt_(refcnt) {} + + // decreases reference count on destruction + 
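+  // Lifecycle of the shared refcnt slot (see the memory layout comment on
+  // ShmCache below): 0 = free, 1 = acquired by the writer in CreateBuffer(),
+  // 2 = also opened by the reader in OpenBuffer(). Writer and reader each drop
+  // one reference, returning the slot to 0 so it can be reused.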
~CachedBuffer() { + const uint8_t current = atomic_fetch_sub(refcnt_, 1); + DCHECK(current == 1 || current == 2); + } + + private: + // holds the parent buffer as refcnt_ locates there + std::shared_ptr parent_; + atomic_uchar* refcnt_; + }; + + public: + static arrow::Result> Create(const std::string& name_prefix, + int buffer_count, + int buffer_size) { + return CreateOrOpen(name_prefix, buffer_count, buffer_size, /*create=*/true); + } + + static arrow::Result> Open(const std::string& name_prefix, + int buffer_count, + int buffer_size) { + return CreateOrOpen(name_prefix, buffer_count, buffer_size, /*create=*/false); + } + + ShmCache(const std::string& name_prefix, int64_t buffer_count, int64_t buffer_size, + int64_t buffer_offset, std::shared_ptr buffer, atomic_uchar* refcnt) + : name_prefix_(name_prefix), + buffer_count_(buffer_count), + buffer_size_(buffer_size), + buffer_offset_(buffer_offset), + buffer_(std::move(buffer)), + refcnt_(refcnt) {} + + // called by writer + arrow::Result>> CreateBuffer( + int64_t size, bool client) { + if (size <= buffer_size_) { + // acquire free cached buffer + for (int i = 0; i < buffer_count_; ++i) { + uint8_t expected = 0; + if (atomic_compare_exchange_strong(&refcnt_[i], &expected, 1)) { + auto buffer = std::make_shared( + SliceMutableBuffer(buffer_, buffer_offset_ + buffer_size_ * i, size), + &refcnt_[i]); + // append buffer index to name + return std::make_pair(cached_buffer_prefix() + std::to_string(i), + std::move(buffer)); + } + } + } + // fallback to one-shot buffer if no cached buffer available + static std::atomic counter{0}; + const std::string name = + name_prefix_ + (client ? "c" : "s") + std::to_string(++counter); + ARROW_ASSIGN_OR_RAISE(auto buffer, ShmBuffer::Create(name, size)); + return std::make_pair(name, std::move(buffer)); + } + + // called by reader + arrow::Result> OpenBuffer(const std::string& shm_name, + int64_t size) { + const std::string name_prefix = cached_buffer_prefix(); + if (shm_name.find(name_prefix) == 0) { + const int i = std::stoi( + shm_name.substr(name_prefix.size(), shm_name.size() - name_prefix.size())); + if (i < 0 || i >= buffer_count_) { + return Status::IOError("invalid buffer index"); + } + // normally, writer is still holding the buffer, refcnt should be 1 + // but writer stream may have be destroyed which decreased refcnt to 0 + const uint8_t current = atomic_fetch_add(&refcnt_[i], 1); + DCHECK(current == 0 || current == 1); + return std::make_shared( + SliceMutableBuffer(buffer_, buffer_offset_ + buffer_size_ * i, size), + &refcnt_[i]); + } + return ShmBuffer::Open(shm_name, size); + } + + private: + static arrow::Result> CreateOrOpen( + const std::string& name_prefix, int64_t buffer_count, int64_t buffer_size, + bool create) { + if (buffer_count < 0 || buffer_size < 0) { + return Status::Invalid("invalid buffer count or size"); + } + if (buffer_count == 0 || buffer_size == 0) { + return arrow::internal::make_unique(name_prefix, 0, 0, 0, nullptr, + nullptr); + } + + // allocate all buffers at once, prepend refcnt[] array + buffer_size = bit_util::RoundUpToPowerOf2(buffer_size, 64); + const int64_t buffer_offset = bit_util::RoundUpToPowerOf2(buffer_count, 64); + const int64_t total_size = buffer_offset + buffer_size * buffer_count; + ARROW_ASSIGN_OR_RAISE(auto buffer, + create ? 
ShmBuffer::Create(name_prefix + "cc-shm", total_size) + : ShmBuffer::Open(name_prefix + "cc-shm", total_size)); + atomic_uchar* refcnt = reinterpret_cast(buffer->mutable_data()); + if (create) { + std::memset(refcnt, 0, buffer_count); + } + return arrow::internal::make_unique( + name_prefix, buffer_count, buffer_size, buffer_offset, std::move(buffer), refcnt); + } + + std::string cached_buffer_prefix() { return name_prefix_ + "cc-buf"; } + + const std::string name_prefix_; + // shared buffer + // - all cached buffers are allocate in one continuous buffer + // - prepends refcnt[] array with each element refers to one buffer + // * shared and accessed as atomic variables by client/server + // * acquiring steps: 0 - free, 1 - created by writer, 2 - opened by reader + // * releasing steps: both writer and reader decreased refcnt by 1 + // - memory layout + // +----------------+--------------+-----+------------------+ + // content | refcnt[count_] | buffer0 | ... | buffer[count_-1] | + // +----------------+--------------+-----+------------------+ + // size | buffer_count_ | buffer_size_ | ... | buffer_size_ | + // +----------------+--------------+-----+------------------+ + // ^ ^ + // | | + // 0 buffer_offset_ + const int64_t buffer_count_, buffer_size_, buffer_offset_; + std::shared_ptr buffer_; + atomic_uchar* refcnt_; + static_assert(sizeof(atomic_uchar) == 1, ""); +}; + +// fifo for client and server messages +class ShmFifo { + public: + // create and open two fifos for read/write, called by client + static arrow::Result> Create( + const std::string& fifo_name, std::unique_ptr&& shm_cache) { + const std::string reader_path = "/tmp/" + fifo_name + "s2c"; + const std::string writer_path = "/tmp/" + fifo_name + "c2s"; + if (mkfifo(reader_path.c_str(), 0666) == -1) { + return Status::IOError("mkfifo: ", strerror(errno)); + } + if (mkfifo(writer_path.c_str(), 0666) == -1) { + unlink(reader_path.c_str()); + return Status::IOError("mkfifo: ", strerror(errno)); + } + // fifo open hangs if RDONLY or WRONLY (until peer opens) + int reader_fd = open(reader_path.c_str(), O_RDWR); + int writer_fd = open(writer_path.c_str(), O_RDWR); + if (reader_fd == -1 || writer_fd == -1) { + unlink(reader_path.c_str()); + unlink(writer_path.c_str()); + if (reader_fd != -1) close(reader_fd); + if (writer_fd != -1) close(writer_fd); + return Status::IOError("create fifo"); + } + return arrow::internal::make_unique(std::move(shm_cache), reader_fd, + writer_fd); + } + + // open existing fifos, called by server + static arrow::Result> Open( + const std::string& fifo_name, std::unique_ptr&& shm_cache) { + const std::string reader_path = "/tmp/" + fifo_name + "c2s"; + const std::string writer_path = "/tmp/" + fifo_name + "s2c"; + int reader_fd = open(reader_path.c_str(), O_RDWR); + int writer_fd = open(writer_path.c_str(), O_RDWR); + // we've opened the fifos, unlink fifo files + unlink(reader_path.c_str()); + unlink(writer_path.c_str()); + if (reader_fd == -1 || writer_fd == -1) { + if (reader_fd != -1) close(reader_fd); + if (writer_fd != -1) close(writer_fd); + return Status::IOError("open fifos"); + } + return arrow::internal::make_unique(std::move(shm_cache), reader_fd, + writer_fd); + } + + ShmFifo(std::unique_ptr&& shm_cache, int reader_fd, int writer_fd) + : shm_cache_(std::move(shm_cache)), reader_fd_(reader_fd), writer_fd_(writer_fd) { + pipe(pipe_fds_); + reader_thread_ = std::thread([this] { ReaderThread(); }); + } + + ~ShmFifo() { + StopReaderThread(); + close(pipe_fds_[0]); + close(pipe_fds_[1]); + 
close(reader_fd_); + // XXX: if writer fifo closes immediately after writing finish message + // the reader fifo may not see the messag and timeout (both ends must + // be open for fifo to work properly) + // it won't happen after replacing fifo with unix socket + // below horrible code is a temporary workaround for this issue + int writer_fd = writer_fd_; + std::thread([writer_fd]() { + sleep(1); + close(writer_fd); + }).detach(); + } + + Status WriteData(ShmMsg::Type msg_type, std::shared_ptr buffer, + const std::string& shm_name) { + DCHECK(msg_type == ShmMsg::GetData || msg_type == ShmMsg::PutData); + if (peer_finished_) { + return Status::IOError("peer finished or cancelled"); + } + + const int64_t size = buffer->size(); + ARROW_ASSIGN_OR_RAISE(ShmMsg msg, ShmMsg::Make(msg_type, size, shm_name)); + + // hold write buffer in held_writers_ map to wait for peer response + // it must be done before writing os fifo, in case peer responses fast + { + const std::lock_guard lock(held_writers_mtx_); + DCHECK_EQ(held_writers_.find(shm_name), held_writers_.end()); + held_writers_[shm_name] = std::move(buffer); + } + + // write to os fifo + const Status st = OsWriteFifo(msg); + if (!st.ok()) { + // release the buffer on error + const std::lock_guard lock(held_writers_mtx_); + const auto n = held_writers_.erase(shm_name); + DCHECK_EQ(n, 1); + } + return st; + } + + Status WriteCtrl(ShmMsg::Type msg_type) { + ARROW_ASSIGN_OR_RAISE(ShmMsg msg, ShmMsg::Make(msg_type)); + return OsWriteFifo(msg); + } + + arrow::Result>> ReadMsg() { + RETURN_NOT_OK(reader_sem_.Acquire(1)); + const std::lock_guard lock(reader_queue_mtx_); + auto v = reader_queue_.front(); + reader_queue_.pop(); + return std::move(v); + } + + arrow::Result>> CreateBuffer( + int64_t size, bool client) { + return shm_cache_->CreateBuffer(size, client); + } + + private: + void ReaderThread() { + ShmMsg msg; + while (OsReadFifo(&msg).ok()) { + DCHECK_EQ(msg.magic, ShmMsg::kMagic); + + switch (msg.type) { + // data + case ShmMsg::GetData: + case ShmMsg::PutData: + // create buffer per received message, append to reader queue + { + ShmMsg::Type msg_type = msg.type; + std::shared_ptr buffer; + auto result = shm_cache_->OpenBuffer(msg.shm_name, msg.size); + if (result.ok()) { + buffer = result.ValueOrDie(); + } else { + msg_type = + msg.type == ShmMsg::GetData ? ShmMsg::GetInvalid : ShmMsg::PutInvalid; + } + { + std::lock_guard lock(reader_queue_mtx_); + reader_queue_.emplace(msg_type, std::move(buffer)); + } + ARROW_UNUSED(reader_sem_.Release(1)); + // send response so the writer can free its buffer + if (msg.type == ShmMsg::GetData) { + msg.type = result.ok() ? ShmMsg::GetAck : ShmMsg::GetNak; + } else { + msg.type = result.ok() ? 
ShmMsg::PutAck : ShmMsg::PutNak; + } + DCHECK_OK(OsWriteFifo(msg)); + } + break; + // writes done, error, finish + case ShmMsg::GetDone: + case ShmMsg::PutDone: + case ShmMsg::GetInvalid: + case ShmMsg::PutInvalid: + case ShmMsg::Finish: { + const std::lock_guard lock(reader_queue_mtx_); + reader_queue_.emplace(msg.type, std::shared_ptr()); + } + ARROW_UNUSED(reader_sem_.Release(1)); + if (msg.type == ShmMsg::Finish) { + peer_finished_ = true; + } + break; + // data response + case ShmMsg::GetAck: + case ShmMsg::GetNak: + case ShmMsg::PutAck: + case ShmMsg::PutNak: + // release according write buffer + { + const std::lock_guard lock(held_writers_mtx_); + const auto n = held_writers_.erase(msg.shm_name); + DCHECK_EQ(n, 1); + } + break; + default: + DCHECK(false); + break; + } + + std::memset(&msg, 0, sizeof(ShmMsg)); + } + } + + // make sure to write/read a full message per call + Status OsWriteFifo(const ShmMsg& msg) { + if (fifo_write_error_) { + return Status::IOError("fifo write error"); + } + const uint8_t* buf = reinterpret_cast(&msg); + size_t count = sizeof(ShmMsg); + while (count > 0) { + ssize_t ret = write(writer_fd_, buf, count); + if (ret == -1 && errno != EINTR) { + fifo_write_error_ = true; + return Status::IOError("write: ", strerror(errno)); + } + count -= ret; + buf += ret; + } + return Status::OK(); + } + + Status OsReadFifo(ShmMsg* msg) { + struct pollfd fds[2]; + fds[0].fd = reader_fd_; + fds[0].events = POLLIN; + fds[1].fd = pipe_fds_[0]; + fds[1].events = POLLIN; + + uint8_t* buf = reinterpret_cast(msg); + size_t count = sizeof(ShmMsg); + while (count > 0) { + // force checking stop token every 5 seconds, in case pipe method fails + if (poll(fds, 2, 5000) == -1) { + if (errno == EINTR) { + continue; + } + return Status::IOError("poll: ", strerror(errno)); + } + if ((fds[1].revents & POLLIN) || stop_) { + // exit thread if pipe received something or stop token is set + return Status::IOError("stop requested"); + } + if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) { + return Status::IOError("error polled"); + } + if (fds[0].revents & POLLIN) { + ssize_t ret = read(reader_fd_, buf, count); + count -= ret; + buf += ret; + } + } + return Status::OK(); + } + + void StopReaderThread() { + stop_ = true; + while (write(pipe_fds_[1], "s", 1) == -1 && errno == EINTR) { + } + reader_thread_.join(); + } + + std::unique_ptr shm_cache_; + // os fifo fd + int reader_fd_, writer_fd_; + bool fifo_write_error_{false}; + std::thread reader_thread_; + // self pipe to stop reader thread + int pipe_fds_[2]; + // force stoping reader thread in case pipe method fails + std::atomic stop_{false}; + // read message queue with timeout (XXX: suitable value?) + arrow::util::CountingSemaphore reader_sem_{/*initial=*/0, /*timeout_seconds=*/5}; + std::queue>> reader_queue_; + std::mutex reader_queue_mtx_; + // write buffers cannot be freed before peer response + std::unordered_map> held_writers_; + std::mutex held_writers_mtx_; + // received finish or cancel message, cannot write anymore + std::atomic peer_finished_{false}; +}; + +struct ShmStreamImpl { + ShmStreamImpl(bool client, StreamType stream_type, std::unique_ptr&& shm_fifo) + : client_(client), stream_type_(stream_type), shm_fifo_(std::move(shm_fifo)) {} + + Status Read(FlightData* data) { + DCHECK_NE(stream_type_, client_ ? StreamType::kPut : StreamType::kGet); + const std::string prefix = client_ ? 
"client: " : "server: "; + if (reads_done_) { + return Status::IOError(prefix, "reads done"); + } + ShmMsg::Type msg_type; + std::shared_ptr buffer; + ARROW_ASSIGN_OR_RAISE(std::tie(msg_type, buffer), shm_fifo_->ReadMsg()); + switch (msg_type) { + case ShmMsg::GetData: + case ShmMsg::PutData: + DCHECK_EQ(msg_type == ShmMsg::GetData, client_); + return Deserialize(std::move(buffer), data); + case ShmMsg::GetDone: + case ShmMsg::PutDone: + DCHECK_EQ(msg_type == ShmMsg::GetDone, client_); + reads_done_ = true; + return Status::IOError(prefix, "peer done writing"); + case ShmMsg::GetInvalid: + case ShmMsg::PutInvalid: + DCHECK_EQ(msg_type == ShmMsg::GetInvalid, client_); + return Status::Invalid(prefix, "recevied invalid payload"); + case ShmMsg::Finish: + DCHECK(!client_); + reads_done_ = writes_done_ = true; + return Status::IOError(prefix, "client finished"); + default: + DCHECK(false); + return Status::Invalid(prefix, "received invalid message"); + } + } + + Status Write(const FlightPayload& payload) { + DCHECK_NE(stream_type_, client_ ? StreamType::kGet : StreamType::kPut); + if (writes_done_) { + return Status::IOError(client_ ? "client" : "server", ": writes done"); + } + int64_t total_size; + auto result = Serialize(payload, &total_size); + if (!result.ok()) { + ARROW_UNUSED( + shm_fifo_->WriteCtrl(client_ ? ShmMsg::PutInvalid : ShmMsg::GetInvalid)); + return result.status(); + } + DCHECK_GT(total_size, 0); + const std::vector& slices = result.ValueOrDie(); + + std::string shm_name; + std::shared_ptr buffer; + ARROW_ASSIGN_OR_RAISE(std::tie(shm_name, buffer), + shm_fifo_->CreateBuffer(total_size, client_)); + CopySlicesToBuffer(slices, buffer.get()); + return shm_fifo_->WriteData(client_ ? ShmMsg::PutData : ShmMsg::GetData, + std::move(buffer), shm_name); + } + + Status WritesDone() { + DCHECK_NE(stream_type_, client_ ? StreamType::kGet : StreamType::kPut); + if (writes_done_) { + return Status::OK(); + } + writes_done_ = true; + return shm_fifo_->WriteCtrl(client_ ? 
ShmMsg::PutDone : ShmMsg::GetDone); + } + + static void CopySlicesToBuffer(const std::vector& slices, + Buffer* buffer) { + uint8_t* dest_ptr = buffer->mutable_data(); + for (const auto& slice : slices) { + DCHECK_LE(dest_ptr + slice.size(), buffer->data() + buffer->size()); + std::memcpy(dest_ptr, slice.data(), static_cast(slice.size())); + dest_ptr += slice.size(); + } + DCHECK_EQ(dest_ptr, buffer->data() + buffer->size()); + } + + const bool client_; + const StreamType stream_type_; + std::unique_ptr shm_fifo_; + std::atomic reads_done_{false}, writes_done_{false}; +}; + +class ShmClientStream : public DataClientStream { + public: + ShmClientStream(StreamType stream_type, std::unique_ptr&& shm_fifo) + : stream_(/*client=*/true, stream_type, std::move(shm_fifo)) {} + + ~ShmClientStream() { ARROW_UNUSED(Finish()); } + + Status Read(FlightData* data) override { return stream_.Read(data); } + Status Write(const FlightPayload& payload) override { return stream_.Write(payload); } + Status WritesDone() override { return stream_.WritesDone(); } + + Status Finish() override { + stream_.writes_done_ = stream_.reads_done_ = true; + return stream_.shm_fifo_->WriteCtrl(ShmMsg::Finish); + } + + void TryCancel() override { + ARROW_UNUSED(stream_.shm_fifo_->WriteCtrl(ShmMsg::Finish)); + } + + private: + ShmStreamImpl stream_; +}; + +class ShmServerStream : public DataServerStream { + public: + ShmServerStream(StreamType stream_type, std::unique_ptr&& shm_fifo) + : stream_(/*client=*/false, stream_type, std::move(shm_fifo)) {} + + ~ShmServerStream() { + if (stream_.stream_type_ != StreamType::kPut) { + ARROW_UNUSED(WritesDone()); + } + } + + Status Read(FlightData* data) override { return stream_.Read(data); } + Status Write(const FlightPayload& payload) override { return stream_.Write(payload); } + Status WritesDone() override { return stream_.WritesDone(); } + + private: + ShmStreamImpl stream_; +}; + +class ShmClientDataPlane : public ClientDataPlane { + private: + ResultClientStream DoGetImpl(StreamMap* map) override { + return Do(map, StreamType::kGet); + } + + ResultClientStream DoPutImpl(StreamMap* map) override { + return Do(map, StreamType::kPut); + } + + ResultClientStream DoExchangeImpl(StreamMap* map) override { + return Do(map, StreamType::kExchange); + } + + ResultClientStream Do(StreamMap* map, StreamType stream_type) { + const std::string name_prefix = GenerateNamePrefix(); + (*map)[kIpcNamePrefix] = name_prefix; + ARROW_ASSIGN_OR_RAISE(auto shm_cache, + ShmCache::Create(name_prefix, kBufferCount, kBufferSize)); + ARROW_ASSIGN_OR_RAISE(auto shm_fifo, + ShmFifo::Create(name_prefix, std::move(shm_cache))); + return arrow::internal::make_unique(stream_type, + std::move(shm_fifo)); + } + + // generate system unique name prefix for ipc objects + // prefix = "/flight-shm-{client pid}-{stream counter}-" + std::string GenerateNamePrefix() { + static std::atomic counter{0}; + std::stringstream name_prefix; + name_prefix << "/flight-shm-" << getpid() << '-' << ++counter << '-'; + return name_prefix.str(); + } +}; + +class ShmServerDataPlane : public ServerDataPlane { + private: + ResultServerStream DoGetImpl(const StreamMap& map) override { + return Do(map, StreamType::kGet); + } + + ResultServerStream DoPutImpl(const StreamMap& map) override { + return Do(map, StreamType::kPut); + } + + ResultServerStream DoExchangeImpl(const StreamMap& map) override { + return Do(map, StreamType::kExchange); + } + + std::vector stream_keys() override { return {kIpcNamePrefix}; } + + ResultServerStream Do(const 
StreamMap& map, StreamType stream_type) { + ARROW_ASSIGN_OR_RAISE(auto name_prefix, GetNamePrefix(map)); + ARROW_ASSIGN_OR_RAISE(auto shm_cache, + ShmCache::Open(name_prefix, kBufferCount, kBufferSize)); + ARROW_ASSIGN_OR_RAISE(auto shm_fifo, + ShmFifo::Open(name_prefix, std::move(shm_cache))); + return arrow::internal::make_unique(stream_type, + std::move(shm_fifo)); + } + + // extract ipc objecs name prefix from stream map set by client + arrow::Result GetNamePrefix(const StreamMap& map) { + auto it = map.find(kIpcNamePrefix); + if (it == map.end()) { + return Status::Invalid("key not found: ", kIpcNamePrefix); + } + return it->second; + } +}; + +arrow::Result> MakeShmClientDataPlane( + const std::string&) { + return arrow::internal::make_unique(); +} + +arrow::Result> MakeShmServerDataPlane( + const std::string&) { + return arrow::internal::make_unique(); +} + +} // namespace + +DataPlaneMaker GetShmDataPlaneMaker() { + return {MakeShmClientDataPlane, MakeShmServerDataPlane}; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/types.cc b/cpp/src/arrow/flight/data_plane/types.cc new file mode 100644 index 00000000000..7733bb1feb6 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/types.cc @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
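The registry and URI selection implemented below choose the data plane per process from the FLIGHT_DATAPLANE environment variable (an empty or unset value keeps the plain gRPC path). A rough usage sketch, assuming a build with ARROW_FLIGHT_DP_SHM enabled; the endpoint and function name are hypothetical, and the out-parameter client API follows the style this patch targets:

#include <cstdlib>
#include <memory>

#include "arrow/flight/client.h"
#include "arrow/status.h"

arrow::Status ConnectOverShm(std::unique_ptr<arrow::flight::FlightClient>* client) {
  // Read by DataUriFromLocation() below when the client data plane is created;
  // an empty or unset value keeps the default gRPC-only path (POSIX setenv).
  setenv("FLIGHT_DATAPLANE", "shm", /*overwrite=*/1);
  arrow::flight::Location location;
  // Hypothetical endpoint; the data plane choice is orthogonal to the location.
  ARROW_RETURN_NOT_OK(arrow::flight::Location::ForGrpcTcp("localhost", 31337, &location));
  return arrow::flight::FlightClient::Connect(location, client);
}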
+ +#include "arrow/flight/data_plane/types.h" +#include "arrow/flight/data_plane/internal.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" + +#ifdef GRPCPP_PP_INCLUDE +#include +#else +#include +#endif + +namespace arrow { +namespace flight { +namespace internal { + +namespace { + +// data plane registry (name -> data plane maker) +struct Registry { + std::map makers; + + // register all data planes on creation of registry singleton + Registry() { +#ifdef FLIGHT_DP_SHM + Register("shm", GetShmDataPlaneMaker()); +#endif + } + + static const Registry& instance() { + static const Registry registry; + return registry; + } + + void Register(const std::string& name, const DataPlaneMaker& maker) { + DCHECK_EQ(makers.find(name), makers.end()); + makers[name] = maker; + } + + arrow::Result GetDataPlaneMaker(const std::string& uri) const { + const std::string name = uri.substr(0, uri.find(':')); + auto it = makers.find(name); + if (it == makers.end()) { + return Status::Invalid("unknown data plane: ", name); + } + return it->second; + } +}; + +std::string GetGrpcMetadata(const grpc::ServerContext& context, const std::string& key) { + const auto client_metadata = context.client_metadata(); + const auto found = client_metadata.find(key); + std::string token; + if (found == client_metadata.end()) { + DCHECK(false); + token = ""; + } else { + token = std::string(found->second.data(), found->second.length()); + } + return token; +} + +// TODO(yibo): getting data plane uri from env var is bad, shall we extend +// location to support two uri (control, data)? or any better approach to +// negotiate data plane uri? +std::string DataUriFromLocation(const Location& /*location*/) { + auto result = arrow::internal::GetEnvVar("FLIGHT_DATAPLANE"); + if (result.ok()) { + return result.ValueOrDie(); + } + return ""; // empty uri -> default grpc data plane +} + +} // namespace + +arrow::Result> ClientDataPlane::Make( + const Location& location) { + const std::string uri = DataUriFromLocation(location); + if (uri.empty()) return arrow::internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(auto maker, Registry::instance().GetDataPlaneMaker(uri)); + return maker.make_client(uri); +} + +arrow::Result> ServerDataPlane::Make( + const Location& location) { + const std::string uri = DataUriFromLocation(location); + if (uri.empty()) return arrow::internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(auto maker, Registry::instance().GetDataPlaneMaker(uri)); + return maker.make_server(uri); +} + +void ClientDataPlane::AppendStreamMap( + grpc::ClientContext* context, const std::map& stream_map) { + for (const auto& kv : stream_map) { + context->AddMetadata(kv.first, kv.second); + } +} + +std::map ServerDataPlane::ExtractStreamMap( + const grpc::ServerContext& context) { + std::map stream_map; + for (const auto& key : stream_keys()) { + stream_map[key] = GetGrpcMetadata(context, key); + } + return stream_map; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/data_plane/types.h b/cpp/src/arrow/flight/data_plane/types.h new file mode 100644 index 00000000000..aa2ced776c2 --- /dev/null +++ b/cpp/src/arrow/flight/data_plane/types.h @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/result.h" + +namespace grpc { + +class ClientContext; +class ServerContext; + +} // namespace grpc + +namespace arrow { +namespace flight { + +struct FlightPayload; +struct Location; + +namespace internal { + +struct FlightData; + +// A data plane stream simulates grpc::Client|ServerReader[Writer]: +// - DataClientStream is created before DataServerStream +// - all interfaces should return IOError on data-plane-related errors +// - WritesDone closes the writer side; the operation should be idempotent. +// After that, the writer should return IOError on "Write" and the reader +// should return IOError on "Read" (pending reads on the peer are not discarded) +// - Finish closes the client; no more reads or writes can be performed. It should +// be idempotent, and after that both the client and the server should return +// IOError on "Read" or "Write" (pending reads on the server are not discarded) +// - TryCancel tells the server to stop writing; it should be idempotent +// - concurrency: +// * a single data plane may create several client/server stream pairs which +// run in parallel +// * for a single stream, Read may run in parallel with Write (DoExchange); +// Read/Read and Write/Write normally run in sequence (DoGet/DoPut), but +// implementations should behave correctly if multiple Reads or Writes run in parallel + +// NOTE: gRPC defines separate classes for Reader/Writer/ReaderWriter; +// the data plane implements all of the read/write interfaces in one class + +struct DataClientStream { + virtual ~DataClientStream() = default; + + virtual Status Read(FlightData* data) = 0; + virtual Status Write(const FlightPayload& payload) = 0; + virtual Status WritesDone() = 0; + virtual Status Finish() = 0; + virtual void TryCancel() = 0; +}; + +struct DataServerStream { + virtual ~DataServerStream() = default; + + virtual Status Read(FlightData* data) = 0; + virtual Status Write(const FlightPayload& payload) = 0; + // gRPC doesn't implement a server-side WritesDone + virtual Status WritesDone() = 0; +}; + +// The data plane is initialized at client/server startup; it creates data streams +// that replace the gRPC streams for Flight payload transmission (get/put/exchange).
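To make the stream contract above concrete, the following client-side sketch shows how a caller is expected to drive a DataClientStream: write payloads, call WritesDone() once, read until the stream reports an error, then Finish(). The helper names (WriteAll, ReadAll) and the exact include paths are illustrative assumptions, not part of this patch.

```cpp
// Illustrative sketch of the DataClientStream contract described above.
// WriteAll/ReadAll are made-up helpers; include paths are assumed.
#include <vector>

#include "arrow/flight/data_plane/types.h"        // DataClientStream (this patch)
#include "arrow/flight/serialization_internal.h"  // internal::FlightData (assumed)
#include "arrow/flight/types.h"                   // FlightPayload
#include "arrow/status.h"

namespace flight = arrow::flight;

// Writer side (DoPut-like): write every payload, then signal completion once.
arrow::Status WriteAll(flight::internal::DataClientStream* stream,
                       const std::vector<flight::FlightPayload>& payloads) {
  for (const auto& payload : payloads) {
    // Per the contract, data plane failures surface as IOError.
    ARROW_RETURN_NOT_OK(stream->Write(payload));
  }
  // Idempotent; afterwards further Write() calls must return IOError.
  return stream->WritesDone();
}

// Reader side (DoGet-like): read until the stream is closed, then Finish().
arrow::Status ReadAll(flight::internal::DataClientStream* stream) {
  flight::internal::FlightData data;
  while (stream->Read(&data).ok()) {
    // Consume data.metadata / data.body here.
  }
  // Finish() is idempotent; afterwards both sides see IOError on Read/Write.
  return stream->Finish();
}
```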
+ +class ClientDataPlane { + public: + // client can send a {str:str} map to server together with grpc metadata + // this is useful to match client data stream with server data stream + using StreamMap = std::map; + using ResultClientStream = arrow::Result>; + + virtual ~ClientDataPlane() = default; + + static arrow::Result> Make(const Location& location); + + ResultClientStream DoGet(grpc::ClientContext* context) { + StreamMap stream_map; + ARROW_ASSIGN_OR_RAISE(auto data_stream, DoGetImpl(&stream_map)); + AppendStreamMap(context, stream_map); + return data_stream; + } + + ResultClientStream DoPut(grpc::ClientContext* context) { + StreamMap stream_map; + ARROW_ASSIGN_OR_RAISE(auto data_stream, DoPutImpl(&stream_map)); + AppendStreamMap(context, stream_map); + return data_stream; + } + + ResultClientStream DoExchange(grpc::ClientContext* context) { + StreamMap stream_map; + ARROW_ASSIGN_OR_RAISE(auto data_stream, DoExchangeImpl(&stream_map)); + AppendStreamMap(context, stream_map); + return data_stream; + } + + private: + // implement empty data plane in base class + virtual ResultClientStream DoGetImpl(StreamMap* stream_map) { return NULLPTR; } + virtual ResultClientStream DoPutImpl(StreamMap* stream_map) { return NULLPTR; } + virtual ResultClientStream DoExchangeImpl(StreamMap* stream_map) { return NULLPTR; } + + void AppendStreamMap(grpc::ClientContext* context, const StreamMap& stream_map); +}; + +class ServerDataPlane { + public: + using StreamMap = std::map; + using ResultServerStream = arrow::Result>; + + virtual ~ServerDataPlane() = default; + + static arrow::Result> Make(const Location& location); + + ResultServerStream DoGet(const grpc::ServerContext& context) { + return DoGetImpl(ExtractStreamMap(context)); + } + + ResultServerStream DoPut(const grpc::ServerContext& context) { + return DoPutImpl(ExtractStreamMap(context)); + } + + ResultServerStream DoExchange(const grpc::ServerContext& context) { + return DoExchangeImpl(ExtractStreamMap(context)); + } + + private: + virtual ResultServerStream DoGetImpl(const StreamMap& stream_map) { return NULLPTR; } + virtual ResultServerStream DoPutImpl(const StreamMap& stream_map) { return NULLPTR; } + virtual ResultServerStream DoExchangeImpl(const StreamMap& stream_map) { + return NULLPTR; + } + virtual std::vector stream_keys() { return {}; } + + StreamMap ExtractStreamMap(const grpc::ServerContext& context); +}; + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index 7efd034ad25..a8c9b0c812d 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -185,7 +185,7 @@ class FlightPerfServer : public FlightServerBase { std::unique_ptr* data_stream) override { perf::Token token; CHECK_PARSE(token.ParseFromString(request.ticket)); - return GetPerfBatches(token, perf_schema_, false, data_stream); + return GetPerfBatches(token, perf_schema_, /*use_verifier=*/false, data_stream); } Status DoPut(const ServerCallContext& context, diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index 5f7d0cc487c..de4e5469fa4 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -87,11 +87,11 @@ bool ReadPayload(grpc::ClientReaderWriter* reader // The Flight reader can then peek at the message to determine whether // it has application metadata or not, and pass the message to // RecordBatchStreamReader as 
appropriate. -template +template class PeekableFlightDataReader { public: - explicit PeekableFlightDataReader(ReaderPtr stream) - : stream_(stream), peek_(), finished_(false), valid_(false) {} + explicit PeekableFlightDataReader(Reader stream) + : stream_(std::move(stream)), peek_(), finished_(false), valid_(false) {} void Peek(internal::FlightData** out) { *out = nullptr; @@ -132,7 +132,7 @@ class PeekableFlightDataReader { return valid_; } - if (!internal::ReadPayload(&*stream_, &peek_)) { + if (!stream_.Read(&peek_).ok()) { finished_ = true; valid_ = false; } else { @@ -141,7 +141,7 @@ class PeekableFlightDataReader { return valid_; } - ReaderPtr stream_; + Reader stream_; internal::FlightData peek_; bool finished_; bool valid_; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 988bd690d51..517dcff1ecb 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -62,11 +62,7 @@ #include "arrow/flight/server_middleware.h" #include "arrow/flight/types.h" -using FlightService = arrow::flight::protocol::FlightService; -using ServerContext = grpc::ServerContext; - -template -using ServerWriter = grpc::ServerWriter; +#include "arrow/flight/data_plane/types.h" namespace arrow { namespace flight { @@ -96,6 +92,67 @@ namespace pb = arrow::flight::protocol; } \ } while (false) +namespace internal { + +template +Status ReadServer(GrpcServerStream* grpc_stream, DataServerStream* data_stream, + FlightData* payload) { + if (data_stream) { + return data_stream->Read(payload); + } else { + const bool ok = internal::ReadPayload(grpc_stream, payload); + return ok ? Status::OK() : Status::IOError("gRPC server read failed"); + } +} + +template +Status WriteServer(GrpcServerStream* grpc_stream, DataServerStream* data_stream, + const FlightPayload& payload) { + if (data_stream) { + return data_stream->Write(payload); + } else { + return internal::WritePayload(payload, grpc_stream); + } +} + +template +Status WritesDoneServer(GrpcServerStream*, DataServerStream* data_stream) { + if (data_stream) { + RETURN_NOT_OK(data_stream->WritesDone()); + } + // grpc server doesn't implement WritesDone + return Status::OK(); +} + +struct ServerWriter { + using GrpcServerStream = grpc::ServerWriter; + GrpcServerStream* grpc_stream; + DataServerStream* data_stream; + + Status Write(const FlightPayload& payload) { + return WriteServer(grpc_stream, data_stream, payload); + } + Status WritesDone() { return WritesDoneServer(grpc_stream, data_stream); } +}; + +// WriteT can be pb::FlightData or pb::PutResult +template +struct ServerReaderWriter { + using GrpcServerStream = grpc::ServerReaderWriter; + GrpcServerStream* grpc_stream; + DataServerStream* data_stream; + + Status Read(FlightData* payload) { + return ReadServer(grpc_stream, data_stream, payload); + } + Status Write(const FlightPayload& payload) { + return WriteServer(grpc_stream, data_stream, payload); + } + Status WritesDone() { return WritesDoneServer(grpc_stream, data_stream); } +}; + +} // namespace internal + namespace { // A MessageReader implementation that reads from a gRPC ServerReader. 
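The serialization_internal.h change above turns PeekableFlightDataReader into a template over any reader held by value that exposes Status Read(FlightData*), instead of a pointer to a concrete gRPC reader; that is what lets the server wrap either a raw gRPC stream or a data plane stream in internal::ServerReaderWriter. A hypothetical adapter like the one below (not part of the patch, include paths assumed) is enough to satisfy the new requirement:

```cpp
// Hypothetical adapter: any value type with Status Read(internal::FlightData*)
// can now back PeekableFlightDataReader, just like the internal::ServerReaderWriter
// wrappers added in server.cc.
#include <cstddef>
#include <utility>
#include <vector>

#include "arrow/flight/serialization_internal.h"  // PeekableFlightDataReader (assumed)
#include "arrow/status.h"

namespace flight = arrow::flight;

struct VectorReader {
  std::vector<flight::internal::FlightData> messages;
  std::size_t next = 0;

  // A non-OK status marks the stream as finished for the peekable reader.
  arrow::Status Read(flight::internal::FlightData* out) {
    if (next >= messages.size()) return arrow::Status::IOError("end of stream");
    *out = std::move(messages[next++]);
    return arrow::Status::OK();
  }
};

// Usage sketch:
//   flight::internal::PeekableFlightDataReader<VectorReader> reader(VectorReader{});
//   flight::internal::FlightData* peeked = nullptr;
//   reader.Peek(&peeked);  // stays nullptr here, since the vector is empty
```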
@@ -104,7 +161,7 @@ template class FlightIpcMessageReader : public ipc::MessageReader { public: explicit FlightIpcMessageReader( - std::shared_ptr> peekable_reader, + std::shared_ptr> peekable_reader, std::shared_ptr* app_metadata) : peekable_reader_(peekable_reader), app_metadata_(app_metadata) {} @@ -127,7 +184,7 @@ class FlightIpcMessageReader : public ipc::MessageReader { } protected: - std::shared_ptr> peekable_reader_; + std::shared_ptr> peekable_reader_; // A reference to FlightMessageReaderImpl.app_metadata_. That class // can't access the app metadata because when it Peek()s the stream, // it may be looking at a dictionary batch, not the record @@ -138,14 +195,13 @@ class FlightIpcMessageReader : public ipc::MessageReader { bool stream_finished_ = false; }; -template +template class FlightMessageReaderImpl : public FlightMessageReader { public: - using GrpcStream = grpc::ServerReaderWriter; + using ServerStream = internal::ServerReaderWriter; - explicit FlightMessageReaderImpl(GrpcStream* reader) - : reader_(reader), - peekable_reader_(new internal::PeekableFlightDataReader(reader)) {} + explicit FlightMessageReaderImpl(ServerStream reader) + : peekable_reader_(new internal::PeekableFlightDataReader(reader)) {} Status Init() { // Peek the first message to get the descriptor. @@ -209,7 +265,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { return Status::IOError("Client never sent a data message"); } auto message_reader = std::unique_ptr( - new FlightIpcMessageReader(peekable_reader_, &app_metadata_)); + new FlightIpcMessageReader(peekable_reader_, &app_metadata_)); ARROW_ASSIGN_OR_RAISE( batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader))); } @@ -217,8 +273,7 @@ class FlightMessageReaderImpl : public FlightMessageReader { } FlightDescriptor descriptor_; - GrpcStream* reader_; - std::shared_ptr> peekable_reader_; + std::shared_ptr> peekable_reader_; std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; @@ -284,8 +339,9 @@ class GrpcServerAuthSender : public ServerAuthSender { /// stream for DoExchange. 
class DoExchangeMessageWriter : public FlightMessageWriter { public: - explicit DoExchangeMessageWriter( - grpc::ServerReaderWriter* stream) + using ServerStream = internal::ServerReaderWriter; + + explicit DoExchangeMessageWriter(ServerStream stream) : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {} Status Begin(const std::shared_ptr& schema, @@ -328,15 +384,16 @@ class DoExchangeMessageWriter : public FlightMessageWriter { } Status Close() override { - // It's fine to Close() without writing data - return Status::OK(); + // It's fine to Close() without writing data for grpc + // Data plane needs to tell the reader that we've done writing + return stream_.WritesDone(); } ipc::WriteStats stats() const override { return stats_; } private: Status WritePayload(const FlightPayload& payload) { - RETURN_NOT_OK(internal::WritePayload(payload, stream_)); + RETURN_NOT_OK(stream_.Write(payload)); ++stats_.num_messages; return Status::OK(); } @@ -365,7 +422,7 @@ class DoExchangeMessageWriter : public FlightMessageWriter { return Status::OK(); } - grpc::ServerReaderWriter* stream_; + ServerStream stream_; ::arrow::ipc::IpcWriteOptions ipc_options_; ipc::DictionaryFieldMapper mapper_; ipc::WriteStats stats_; @@ -410,7 +467,7 @@ class GrpcServerCallContext : public ServerCallContext { private: friend class FlightServiceImpl; - ServerContext* context_; + grpc::ServerContext* context_; std::string peer_; std::string peer_identity_; std::vector> middleware_; @@ -432,7 +489,7 @@ class GrpcAddCallHeaders : public AddCallHeaders { // This class glues an implementation of FlightServerBase together with the // gRPC service definition, so the latter is not exposed in the public API -class FlightServiceImpl : public FlightService::Service { +class FlightServiceImpl : public protocol::FlightService::Service { public: explicit FlightServiceImpl( std::shared_ptr auth_handler, @@ -442,7 +499,7 @@ class FlightServiceImpl : public FlightService::Service { : auth_handler_(auth_handler), middleware_(middleware), server_(server) {} template - grpc::Status WriteStream(Iterator* iterator, ServerWriter* writer) { + grpc::Status WriteStream(Iterator* iterator, grpc::ServerWriter* writer) { if (!iterator) { return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate"); } @@ -467,7 +524,7 @@ class FlightServiceImpl : public FlightService::Service { template grpc::Status WriteStream(const std::vector& values, - ServerWriter* writer) { + grpc::ServerWriter* writer) { // Write flight info to stream until listing is exhausted for (const UserType& value : values) { ProtoType pb_value; @@ -482,7 +539,7 @@ class FlightServiceImpl : public FlightService::Service { } // Authenticate the client (if applicable) and construct the call context - grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, + grpc::Status CheckAuth(const FlightMethod& method, grpc::ServerContext* context, GrpcServerCallContext& flight_context) { if (!auth_handler_) { const auto auth_context = context->auth_context(); @@ -496,14 +553,7 @@ class FlightServiceImpl : public FlightService::Service { flight_context.peer_identity_ = ""; } } else { - const auto client_metadata = context->client_metadata(); - const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader); - std::string token; - if (auth_header == client_metadata.end()) { - token = ""; - } else { - token = std::string(auth_header->second.data(), auth_header->second.length()); - } + const std::string token = GetMetadata(context, 
internal::kGrpcAuthHeader); GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_)); } @@ -511,7 +561,7 @@ class FlightServiceImpl : public FlightService::Service { } // Authenticate the client (if applicable) and construct the call context - grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, + grpc::Status MakeCallContext(const FlightMethod& method, grpc::ServerContext* context, GrpcServerCallContext& flight_context) { // Run server middleware const CallInfo info{method}; @@ -542,7 +592,7 @@ class FlightServiceImpl : public FlightService::Service { } grpc::Status Handshake( - ServerContext* context, + grpc::ServerContext* context, grpc::ServerReaderWriter* stream) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( @@ -561,8 +611,8 @@ class FlightServiceImpl : public FlightService::Service { auth_handler_->Authenticate(&outgoing, &incoming)); } - grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request, - ServerWriter* writer) { + grpc::Status ListFlights(grpc::ServerContext* context, const pb::Criteria* request, + grpc::ServerWriter* writer) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::ListFlights, context, flight_context)); @@ -584,7 +634,8 @@ class FlightServiceImpl : public FlightService::Service { WriteStream(listing.get(), writer)); } - grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request, + grpc::Status GetFlightInfo(grpc::ServerContext* context, + const pb::FlightDescriptor* request, pb::FlightInfo* response) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( @@ -609,7 +660,8 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); } - grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, + grpc::Status GetSchema(grpc::ServerContext* context, + const pb::FlightDescriptor* request, pb::SchemaResult* response) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context)); @@ -633,9 +685,12 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); } - grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, - ServerWriter* writer) { + grpc::Status DoGet(grpc::ServerContext* context, const pb::Ticket* request, + grpc::ServerWriter* grpc_writer) { GrpcServerCallContext flight_context(context); + std::unique_ptr data_writer; + SERVICE_RETURN_NOT_OK(flight_context, + server_->data_plane()->DoGet(*context).Value(&data_writer)); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null"); @@ -655,7 +710,8 @@ class FlightServiceImpl : public FlightService::Service { // Write the schema as the first message in the stream FlightPayload schema_payload; SERVICE_RETURN_NOT_OK(flight_context, data_stream->GetSchemaPayload(&schema_payload)); - auto status = internal::WritePayload(schema_payload, writer); + auto writer = internal::ServerWriter{grpc_writer, data_writer.get()}; + auto status = writer.Write(schema_payload); if (status.IsIOError()) { // gRPC doesn't give any way for us to know why the message // could not be written. 
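For readability, the DoGet body above boils down to the following write loop (a simplified sketch, not the patch's exact code; the IOError handling for a terminated connection is elided). The schema payload is always written first, and when the data stream is exhausted the wrapper's WritesDone() is invoked so that a data plane reader also observes end-of-stream:

```cpp
// Simplified sketch of the DoGet write loop. Writer is expected to be the
// internal::ServerWriter wrapper (gRPC stream plus optional data plane stream).
#include "arrow/flight/server.h"  // FlightDataStream
#include "arrow/flight/types.h"   // FlightPayload
#include "arrow/status.h"

template <typename Writer>
arrow::Status DrainDataStream(arrow::flight::FlightDataStream* data_stream,
                              Writer* writer) {
  // The schema is always the first message on the wire.
  arrow::flight::FlightPayload schema_payload;
  ARROW_RETURN_NOT_OK(data_stream->GetSchemaPayload(&schema_payload));
  ARROW_RETURN_NOT_OK(writer->Write(schema_payload));
  while (true) {
    arrow::flight::FlightPayload payload;
    ARROW_RETURN_NOT_OK(data_stream->Next(&payload));
    if (payload.ipc_message.metadata == nullptr) {
      // End of stream: unlike plain gRPC, a data plane needs an explicit
      // WritesDone() so the client-side reader stops waiting.
      return writer->WritesDone();
    }
    ARROW_RETURN_NOT_OK(writer->Write(payload));
  }
}
```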
@@ -668,8 +724,11 @@ class FlightServiceImpl : public FlightService::Service { FlightPayload payload; SERVICE_RETURN_NOT_OK(flight_context, data_stream->Next(&payload)); // End of stream - if (payload.ipc_message.metadata == nullptr) break; - auto status = internal::WritePayload(payload, writer); + if (payload.ipc_message.metadata == nullptr) { + SERVICE_RETURN_NOT_OK(flight_context, writer.WritesDone()); + break; + } + auto status = writer.Write(payload); // Connection terminated if (status.IsIOError()) break; SERVICE_RETURN_NOT_OK(flight_context, status); @@ -677,26 +736,38 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); } - grpc::Status DoPut(ServerContext* context, - grpc::ServerReaderWriter* reader) { + grpc::Status DoPut( + grpc::ServerContext* context, + grpc::ServerReaderWriter* grpc_reader) { GrpcServerCallContext flight_context(context); + std::unique_ptr data_reader; + SERVICE_RETURN_NOT_OK(flight_context, + server_->data_plane()->DoPut(*context).Value(&data_reader)); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); + auto reader = + internal::ServerReaderWriter{grpc_reader, data_reader.get()}; auto message_reader = std::unique_ptr>( new FlightMessageReaderImpl(reader)); SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); auto metadata_writer = - std::unique_ptr(new GrpcMetadataWriter(reader)); + std::unique_ptr(new GrpcMetadataWriter(grpc_reader)); RETURN_WITH_MIDDLEWARE(flight_context, server_->DoPut(flight_context, std::move(message_reader), std::move(metadata_writer))); } grpc::Status DoExchange( - ServerContext* context, - grpc::ServerReaderWriter* stream) { + grpc::ServerContext* context, + grpc::ServerReaderWriter* grpc_stream) { GrpcServerCallContext flight_context(context); + std::unique_ptr data_stream; + SERVICE_RETURN_NOT_OK( + flight_context, server_->data_plane()->DoExchange(*context).Value(&data_stream)); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context)); + + auto stream = + internal::ServerReaderWriter{grpc_stream, data_stream.get()}; auto message_reader = std::unique_ptr>( new FlightMessageReaderImpl(stream)); SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); @@ -707,8 +778,8 @@ class FlightServiceImpl : public FlightService::Service { std::move(writer))); } - grpc::Status ListActions(ServerContext* context, const pb::Empty* request, - ServerWriter* writer) { + grpc::Status ListActions(grpc::ServerContext* context, const pb::Empty* request, + grpc::ServerWriter* writer) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::ListActions, context, flight_context)); @@ -718,8 +789,8 @@ class FlightServiceImpl : public FlightService::Service { RETURN_WITH_MIDDLEWARE(flight_context, WriteStream(types, writer)); } - grpc::Status DoAction(ServerContext* context, const pb::Action* request, - ServerWriter* writer) { + grpc::Status DoAction(grpc::ServerContext* context, const pb::Action* request, + grpc::ServerWriter* writer) { GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null"); @@ -752,6 +823,18 @@ class FlightServiceImpl : public FlightService::Service { } private: + std::string GetMetadata(const grpc::ServerContext* context, const char* key) { + const auto client_metadata = context->client_metadata(); + const auto 
auth_header = client_metadata.find(key); + std::string token; + if (auth_header == client_metadata.end()) { + token = ""; + } else { + token = std::string(auth_header->second.data(), auth_header->second.length()); + } + return token; + } + std::shared_ptr auth_handler_; std::vector>> middleware_; @@ -844,6 +927,7 @@ class ServerSignalHandler { struct FlightServerBase::Impl { std::unique_ptr service_; std::unique_ptr server_; + std::unique_ptr data_plane_; int port_; // Signal handlers (on Windows) and the shutdown handler (other platforms) @@ -909,6 +993,9 @@ FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } FlightServerBase::~FlightServerBase() {} Status FlightServerBase::Init(const FlightServerOptions& options) { + ARROW_ASSIGN_OR_RAISE(impl_->data_plane_, + internal::ServerDataPlane::Make(options.location)); + impl_->service_.reset( new FlightServiceImpl(options.auth_handler, options.middleware, this)); @@ -964,11 +1051,16 @@ Status FlightServerBase::Init(const FlightServerOptions& options) { if (!impl_->server_) { return Status::UnknownError("Server did not start properly"); } + return Status::OK(); } int FlightServerBase::port() const { return impl_->port_; } +internal::ServerDataPlane* FlightServerBase::data_plane() const { + return impl_->data_plane_.get(); +} + Status FlightServerBase::SetShutdownOnSignals(const std::vector sigs) { impl_->signals_ = sigs; impl_->old_signal_handlers_.clear(); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 218b640adc8..a38b3f90529 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -43,6 +43,10 @@ namespace flight { class ServerMiddleware; class ServerMiddlewareFactory; +namespace internal { +class ServerDataPlane; +} // namespace internal + /// \brief Interface that produces a sequence of IPC payloads to be sent in /// FlightData protobuf messages class ARROW_FLIGHT_EXPORT FlightDataStream { @@ -180,6 +184,10 @@ class ARROW_FLIGHT_EXPORT FlightServerBase { /// domain socket). int port() const; + /// \brief Get the data plane of the Flight server. + /// This method must only be called after Init(). + internal::ServerDataPlane* data_plane() const; + /// \brief Set the server to stop when receiving any of the given signal /// numbers. /// This method must be called before Serve().
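End to end, server startup now looks roughly like the sketch below: Init() builds the server data plane from the location (selection currently comes from the FLIGHT_DATAPLANE environment variable per the TODO in types.cc, e.g. FLIGHT_DATAPLANE=shm when the shared-memory driver is compiled in), and data_plane() is only valid afterwards. EchoServer, RunServer, and the port number are illustrative; the RPC handler overrides are omitted.

```cpp
// Minimal server-startup sketch; names and port are illustrative only.
#include "arrow/flight/server.h"
#include "arrow/status.h"

class EchoServer : public arrow::flight::FlightServerBase {
  // DoGet/DoPut/DoExchange overrides would go here.
};

arrow::Status RunServer() {
  arrow::flight::Location location;
  ARROW_RETURN_NOT_OK(
      arrow::flight::Location::ForGrpcTcp("0.0.0.0", 31337, &location));
  arrow::flight::FlightServerOptions options(location);

  EchoServer server;
  // Init() now also creates the server data plane (default: plain gRPC;
  // FLIGHT_DATAPLANE=shm selects the shared-memory driver if built in).
  ARROW_RETURN_NOT_OK(server.Init(options));
  // data_plane() must only be called after Init(), per the accessor above.
  return server.Serve();
}
```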