diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index cae36562be3..7447e675e08 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -41,21 +41,29 @@ if(NOT ARROW_GRPC_USE_SHARED) PARENT_SCOPE) endif() +set(ARROW_FLIGHT_TEST_INTERFACE_LIBS) if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_static arrow_flight_testing_static ${ARROW_FLIGHT_STATIC_LINK_LIBS} ${ARROW_TEST_LINK_LIBS}) if(ARROW_CUDA) + list(APPEND ARROW_FLIGHT_TEST_INTERFACE_LIBS arrow_cuda_static) list(APPEND ARROW_FLIGHT_TEST_LINK_LIBS arrow_cuda_static) endif() else() set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_shared arrow_flight_testing_shared ${ARROW_TEST_LINK_LIBS}) if(ARROW_CUDA) + list(APPEND ARROW_FLIGHT_TEST_INTERFACE_LIBS arrow_cuda_shared) list(APPEND ARROW_FLIGHT_TEST_LINK_LIBS arrow_cuda_shared) endif() endif() +# Needed for Flight SQL and integration +set(ARROW_FLIGHT_TEST_LINK_LIBS + ${ARROW_FLIGHT_TEST_LINK_LIBS} + PARENT_SCOPE) + # TODO(wesm): Protobuf shared vs static linking set(FLIGHT_PROTO_PATH "${ARROW_SOURCE_DIR}/../format") @@ -162,21 +170,40 @@ endif() # Restore the CXXFLAGS that were modified above set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") -# Note, we do not compile the generated Protobuf sources directly, instead +# Note, we do not compile the generated gRPC sources directly, instead # compiling them via protocol_internal.cc which contains some gRPC template # overrides to enable Flight-specific optimizations. 
See comments in # protocol_internal.cc set(ARROW_FLIGHT_SRCS + "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc" client.cc client_cookie_middleware.cc - client_header_internal.cc - internal.cc - protocol_internal.cc + cookie_internal.cc serialization_internal.cc server.cc server_auth.cc + transport.cc + transport_server.cc + # Bundle the gRPC impl with libarrow_flight + transport/grpc/grpc_client.cc + transport/grpc/grpc_server.cc + transport/grpc/serialization_internal.cc + transport/grpc/protocol_grpc_internal.cc + transport/grpc/util_internal.cc types.cc) +if(MSVC) + # Protobuf generated files trigger spurious warnings on MSVC. + foreach(GENERATED_SOURCE "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.h") + # Suppress missing dll-interface warning + set_source_files_properties("${GENERATED_SOURCE}" + PROPERTIES COMPILE_OPTIONS "/wd4251" + GENERATED TRUE + SKIP_UNITY_BUILD_INCLUSION TRUE) + endforeach() +endif() + add_arrow_lib(arrow_flight CMAKE_PACKAGE_NAME ArrowFlight @@ -226,10 +253,12 @@ if(ARROW_TESTING) ${BOOST_FILESYSTEM_LIBRARY} ${BOOST_SYSTEM_LIBRARY} GTest::gtest + ${ARROW_FLIGHT_TEST_INTERFACE_LIBS} STATIC_LINK_LIBS arrow_static arrow_flight_static - arrow_testing_static) + arrow_testing_static + ${ARROW_FLIGHT_TEST_INTERFACE_LIBS}) endif() foreach(LIB_TARGET ${ARROW_FLIGHT_TESTING_LIBRARIES}) @@ -263,14 +292,6 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) add_dependencies(arrow_flight flight-test-server) endif() -if(ARROW_CUDA AND ARROW_BUILD_TESTS) - add_arrow_test(flight_cuda_test - STATIC_LINK_LIBS - ${ARROW_FLIGHT_TEST_LINK_LIBS} - LABELS - "arrow_flight") -endif() - if(ARROW_BUILD_BENCHMARKS) # Perf server for benchmarks set(PERF_PROTO_GENERATED_FILES "${CMAKE_CURRENT_BINARY_DIR}/perf.pb.cc" diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 39d6cebf41f..3067d28c9ed 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -20,51 +20,30 @@ // 
Platform-specific defines #include "arrow/flight/platform.h" -#include #include -#include #include #include -#include #include -#ifdef GRPCPP_PP_INCLUDE -#include -#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) -#include -#endif -#else -#include -#endif - -#include - #include "arrow/buffer.h" +#include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/table.h" -#include "arrow/type.h" #include "arrow/util/logging.h" -#include "arrow/util/uri.h" #include "arrow/flight/client_auth.h" -#include "arrow/flight/client_header_internal.h" -#include "arrow/flight/client_middleware.h" -#include "arrow/flight/internal.h" -#include "arrow/flight/middleware.h" -#include "arrow/flight/middleware_internal.h" #include "arrow/flight/serialization_internal.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/grpc/grpc_client.h" #include "arrow/flight/types.h" namespace arrow { namespace flight { -namespace pb = arrow::flight::protocol; - const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail"; FlightCallOptions::FlightCallOptions() @@ -101,349 +80,18 @@ Status FlightStreamReader::ReadAll(std::shared_ptr* table, return Table::FromRecordBatches(schema, std::move(batches)).Value(table); } -struct ClientRpc { - grpc::ClientContext context; - - explicit ClientRpc(const FlightCallOptions& options) { - if (options.timeout.count() >= 0) { - std::chrono::system_clock::time_point deadline = - std::chrono::time_point_cast( - std::chrono::system_clock::now() + options.timeout); - context.set_deadline(deadline); - } - for (auto header : options.headers) { - context.AddMetadata(header.first, header.second); - } - } - - /// \brief Add an auth token via an auth handler - Status SetToken(ClientAuthHandler* auth_handler) { - if (auth_handler) { - std::string token; - RETURN_NOT_OK(auth_handler->GetToken(&token)); - 
context.AddMetadata(internal::kGrpcAuthHeader, token); - } - return Status::OK(); - } -}; - -/// Helper that manages Finish() of a gRPC stream. -/// -/// When we encounter an error (e.g. could not decode an IPC message), -/// we want to provide both the client-side error context and any -/// available server-side context. This helper helps wrap up that -/// logic. -/// -/// This class protects the stream with a flag (so that Finish is -/// idempotent), and drains the read side (so that Finish won't hang). -/// -/// The template lets us abstract between DoGet/DoExchange and DoPut, -/// which respectively read internal::FlightData and pb::PutResult. -template -class FinishableStream { - public: - FinishableStream(std::shared_ptr rpc, std::shared_ptr stream) - : rpc_(rpc), stream_(stream), finished_(false), server_status_() {} - virtual ~FinishableStream() = default; - - /// \brief Get the underlying stream. - std::shared_ptr stream() const { return stream_; } - - /// \brief Finish the call, adding server context to the given status. - virtual Status Finish(Status st) { - if (finished_) { - return MergeStatus(std::move(st)); - } - - // Drain the read side, as otherwise gRPC Finish() will hang. We - // only call Finish() when the client closes the writer or the - // reader finishes, so it's OK to assume the client no longer - // wants to read and drain the read side. (If the client wants to - // indicate that it is done writing, but not done reading, it - // should use DoneWriting. 
- ReadT message; - while (internal::ReadPayload(stream_.get(), &message)) { - // Drain the read side to avoid gRPC hanging in Finish() - } - - server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context); - finished_ = true; - - return MergeStatus(std::move(st)); - } - - private: - Status MergeStatus(Status&& st) { - if (server_status_.ok()) { - return std::move(st); - } - return Status::FromDetailAndArgs( - server_status_.code(), server_status_.detail(), server_status_.message(), - ". Client context: ", st.ToString(), - ". gRPC client debug context: ", rpc_->context.debug_error_string()); - } - - std::shared_ptr rpc_; - std::shared_ptr stream_; - bool finished_; - Status server_status_; -}; - -/// Helper that manages \a Finish() of a read-write gRPC stream. +/// \brief An ipc::MessageReader adapting the Flight ClientDataStream interface. /// -/// This also calls \a WritesDone() and protects itself with a mutex -/// to enable sharing between the reader and writer. -template -class FinishableWritableStream : public FinishableStream { - public: - FinishableWritableStream(std::shared_ptr rpc, - std::shared_ptr read_mutex, - std::shared_ptr stream) - : FinishableStream(rpc, stream), - finish_mutex_(), - read_mutex_(read_mutex), - done_writing_(false) {} - virtual ~FinishableWritableStream() = default; - - /// \brief Indicate to gRPC that the write half of the stream is done. - Status DoneWriting() { - // This is only used by the writer side of a stream, so it need - // not be protected with a lock. 
- if (done_writing_) { - return Status::OK(); - } - done_writing_ = true; - if (!this->stream()->WritesDone()) { - // Error happened, try to close the stream to get more detailed info - return Finish(MakeFlightError(FlightStatusCode::Internal, - "Could not flush pending record batches")); - } - return Status::OK(); - } - - Status Finish(Status st) override { - // This may be used concurrently by reader/writer side of a - // stream, so it needs to be protected. - std::lock_guard guard(finish_mutex_); - - // Now that we're shared between a reader and writer, we need to - // protect ourselves from being called while there's an - // outstanding read. - std::unique_lock read_guard(*read_mutex_, std::try_to_lock); - if (!read_guard.owns_lock()) { - return MakeFlightError( - FlightStatusCode::Internal, - "Cannot close stream with pending read operation. Client context: " + - st.ToString()); - } - - // Try to flush pending writes. Don't use our WritesDone() to - // avoid recursion. - bool finished_writes = done_writing_ || this->stream()->WritesDone(); - done_writing_ = true; - - st = FinishableStream::Finish(std::move(st)); - - if (!finished_writes) { - return Status::FromDetailAndArgs( - st.code(), st.detail(), st.message(), - ". 
Additionally, could not finish writing record batches before closing"); - } - return st; - } - - private: - std::mutex finish_mutex_; - std::shared_ptr read_mutex_; - bool done_writing_; -}; - -class GrpcAddCallHeaders : public AddCallHeaders { - public: - explicit GrpcAddCallHeaders(std::multimap* metadata) - : metadata_(metadata) {} - ~GrpcAddCallHeaders() override = default; - - void AddHeader(const std::string& key, const std::string& value) override { - metadata_->insert(std::make_pair(key, value)); - } - - private: - std::multimap* metadata_; -}; - -class GrpcClientInterceptorAdapter : public grpc::experimental::Interceptor { - public: - explicit GrpcClientInterceptorAdapter( - std::vector> middleware) - : middleware_(std::move(middleware)), received_headers_(false) {} - - void Intercept(grpc::experimental::InterceptorBatchMethods* methods) { - using InterceptionHookPoints = grpc::experimental::InterceptionHookPoints; - if (methods->QueryInterceptionHookPoint( - InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { - GrpcAddCallHeaders add_headers(methods->GetSendInitialMetadata()); - for (const auto& middleware : middleware_) { - middleware->SendingHeaders(&add_headers); - } - } - - if (methods->QueryInterceptionHookPoint( - InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { - if (!methods->GetRecvInitialMetadata()->empty()) { - ReceivedHeaders(*methods->GetRecvInitialMetadata()); - } - } - - if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_STATUS)) { - DCHECK_NE(nullptr, methods->GetRecvStatus()); - DCHECK_NE(nullptr, methods->GetRecvTrailingMetadata()); - ReceivedHeaders(*methods->GetRecvTrailingMetadata()); - const Status status = internal::FromGrpcStatus(*methods->GetRecvStatus()); - for (const auto& middleware : middleware_) { - middleware->CallCompleted(status); - } - } - - methods->Proceed(); - } - - private: - void ReceivedHeaders( - const std::multimap& metadata) { - if (received_headers_) { - return; - } - 
received_headers_ = true; - CallHeaders headers; - for (const auto& entry : metadata) { - headers.insert({util::string_view(entry.first.data(), entry.first.length()), - util::string_view(entry.second.data(), entry.second.length())}); - } - for (const auto& middleware : middleware_) { - middleware->ReceivedHeaders(headers); - } - } - - std::vector> middleware_; - // When communicating with a gRPC-Java server, the server may not - // send back headers if the call fails right away. Instead, the - // headers will be consolidated into the trailers. We don't want to - // call the client middleware callback twice, so instead track - // whether we saw headers - if not, then we need to check trailers. - bool received_headers_; -}; - -class GrpcClientInterceptorAdapterFactory - : public grpc::experimental::ClientInterceptorFactoryInterface { - public: - GrpcClientInterceptorAdapterFactory( - std::vector> middleware) - : middleware_(middleware) {} - - grpc::experimental::Interceptor* CreateClientInterceptor( - grpc::experimental::ClientRpcInfo* info) override { - std::vector> middleware; - - FlightMethod flight_method = FlightMethod::Invalid; - util::string_view method(info->method()); - if (method.ends_with("/Handshake")) { - flight_method = FlightMethod::Handshake; - } else if (method.ends_with("/ListFlights")) { - flight_method = FlightMethod::ListFlights; - } else if (method.ends_with("/GetFlightInfo")) { - flight_method = FlightMethod::GetFlightInfo; - } else if (method.ends_with("/GetSchema")) { - flight_method = FlightMethod::GetSchema; - } else if (method.ends_with("/DoGet")) { - flight_method = FlightMethod::DoGet; - } else if (method.ends_with("/DoPut")) { - flight_method = FlightMethod::DoPut; - } else if (method.ends_with("/DoExchange")) { - flight_method = FlightMethod::DoExchange; - } else if (method.ends_with("/DoAction")) { - flight_method = FlightMethod::DoAction; - } else if (method.ends_with("/ListActions")) { - flight_method = FlightMethod::ListActions; - 
} else { - DCHECK(false) << "Unknown Flight method: " << info->method(); - } - - const CallInfo flight_info{flight_method}; - for (auto& factory : middleware_) { - std::unique_ptr instance; - factory->StartCall(flight_info, &instance); - if (instance) { - middleware.push_back(std::move(instance)); - } - } - return new GrpcClientInterceptorAdapter(std::move(middleware)); - } - - private: - std::vector> middleware_; -}; - -class GrpcClientAuthSender : public ClientAuthSender { - public: - explicit GrpcClientAuthSender( - std::shared_ptr< - grpc::ClientReaderWriter> - stream) - : stream_(stream) {} - - Status Write(const std::string& token) override { - pb::HandshakeRequest response; - response.set_payload(token); - if (stream_->Write(response)) { - return Status::OK(); - } - return internal::FromGrpcStatus(stream_->Finish()); - } - - private: - std::shared_ptr> - stream_; -}; - -class GrpcClientAuthReader : public ClientAuthReader { - public: - explicit GrpcClientAuthReader( - std::shared_ptr< - grpc::ClientReaderWriter> - stream) - : stream_(stream) {} - - Status Read(std::string* token) override { - pb::HandshakeResponse request; - if (stream_->Read(&request)) { - *token = std::move(*request.mutable_payload()); - return Status::OK(); - } - return internal::FromGrpcStatus(stream_->Finish()); - } - - private: - std::shared_ptr> - stream_; -}; - -// An ipc::MessageReader that adapts any readable gRPC stream -// returning FlightData. -template -class GrpcIpcMessageReader : public ipc::MessageReader { +/// In order to support app_metadata and reuse the existing IPC +/// infrastructure, this takes a pointer to a buffer (provided by the +/// FlightStreamReader implementation) and upon reading a message, +/// updates that buffer with the one read from the server. 
+class IpcMessageReader : public ipc::MessageReader { public: - GrpcIpcMessageReader( - std::shared_ptr rpc, std::shared_ptr memory_manager, - std::shared_ptr read_mutex, - std::shared_ptr> stream, - std::shared_ptr>> - peekable_reader, - std::shared_ptr* app_metadata) - : rpc_(rpc), - memory_manager_(std::move(memory_manager)), - read_mutex_(read_mutex), - stream_(std::move(stream)), + IpcMessageReader(std::shared_ptr stream, + std::shared_ptr peekable_reader, + std::shared_ptr* app_metadata) + : stream_(std::move(stream)), peekable_reader_(peekable_reader), app_metadata_(app_metadata), stream_finished_(false) {} @@ -453,20 +101,11 @@ class GrpcIpcMessageReader : public ipc::MessageReader { return nullptr; } internal::FlightData* data; - { - auto guard = read_mutex_ ? std::unique_lock(*read_mutex_) - : std::unique_lock(); - peekable_reader_->Next(&data); - } + peekable_reader_->Next(&data); if (!data) { stream_finished_ = true; return stream_->Finish(Status::OK()); } - - if (ARROW_PREDICT_FALSE(!memory_manager_->is_cpu() && data->body)) { - ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_)); - } - // Validate IPC message auto result = data->OpenMessage(); if (!result.ok()) { @@ -477,16 +116,10 @@ class GrpcIpcMessageReader : public ipc::MessageReader { } private: - // The RPC context lifetime must be coupled to the ClientReader - std::shared_ptr rpc_; + std::shared_ptr stream_; + std::shared_ptr peekable_reader_; std::shared_ptr memory_manager_; - // Guard reads with a mutex to prevent concurrent reads if the write - // side calls Finish(). Nullable as DoGet doesn't need this. - std::shared_ptr read_mutex_; - std::shared_ptr> stream_; - std::shared_ptr>> - peekable_reader_; - // A reference to GrpcStreamReader.app_metadata_. That class + // A reference to ClientStreamReader.app_metadata_. 
That class // can't access the app metadata because when it Peek()s the stream, // it may be looking at a dictionary batch, not the record // batch. Updating it here ensures the reader is always updated with @@ -495,34 +128,21 @@ class GrpcIpcMessageReader : public ipc::MessageReader { bool stream_finished_; }; -/// The implementation of the public-facing API for reading from a -/// FlightData stream -template -class GrpcStreamReader : public FlightStreamReader { +/// \brief A reader for any ClientDataStream. +class ClientStreamReader : public FlightStreamReader { public: - GrpcStreamReader(std::shared_ptr rpc, - std::shared_ptr memory_manager, - std::shared_ptr read_mutex, - const ipc::IpcReadOptions& options, StopToken stop_token, - std::shared_ptr> stream) - : rpc_(rpc), - memory_manager_(memory_manager ? std::move(memory_manager) - : CPUDevice::Instance()->default_memory_manager()), - read_mutex_(read_mutex), + ClientStreamReader(std::shared_ptr stream, + const ipc::IpcReadOptions& options, StopToken stop_token) + : stream_(std::move(stream)), options_(options), stop_token_(std::move(stop_token)), - stream_(stream), - peekable_reader_(new internal::PeekableFlightDataReader>( - stream->stream())), + peekable_reader_(new internal::PeekableFlightDataReader(stream_.get())), app_metadata_(nullptr) {} Status EnsureDataStarted() { if (!batch_reader_) { bool skipped_to_data = false; - { - auto guard = TakeGuard(); - skipped_to_data = peekable_reader_->SkipToData(); - } + skipped_to_data = peekable_reader_->SkipToData(); // peek() until we find the first data message; discard metadata if (!skipped_to_data) { return OverrideWithServerError(MakeFlightError( @@ -530,8 +150,7 @@ class GrpcStreamReader : public FlightStreamReader { } auto message_reader = std::unique_ptr( - new GrpcIpcMessageReader(rpc_, memory_manager_, read_mutex_, stream_, - peekable_reader_, &app_metadata_)); + new IpcMessageReader(stream_, peekable_reader_, &app_metadata_)); auto result = 
ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_); RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_))); @@ -544,10 +163,7 @@ class GrpcStreamReader : public FlightStreamReader { } Status Next(FlightStreamChunk* out) override { internal::FlightData* data; - { - auto guard = TakeGuard(); - peekable_reader_->Peek(&data); - } + peekable_reader_->Peek(&data); if (!data) { out->app_metadata = nullptr; out->data = nullptr; @@ -558,10 +174,7 @@ class GrpcStreamReader : public FlightStreamReader { // Metadata-only (data->metadata is the IPC header) out->app_metadata = data->app_metadata; out->data = nullptr; - { - auto guard = TakeGuard(); - peekable_reader_->Next(&data); - } + peekable_reader_->Next(&data); return Status::OK(); } @@ -570,7 +183,10 @@ class GrpcStreamReader : public FlightStreamReader { // Re-peek here since EnsureDataStarted() advances the stream return Next(out); } - RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); + auto status = batch_reader_->ReadNext(&out->data); + if (ARROW_PREDICT_FALSE(!status.ok())) { + return stream_->Finish(std::move(status)); + } out->app_metadata = std::move(app_metadata_); return Status::OK(); } @@ -596,14 +212,9 @@ class GrpcStreamReader : public FlightStreamReader { return ReadAll(table, stop_token_); } using FlightStreamReader::ReadAll; - void Cancel() override { rpc_->context.TryCancel(); } + void Cancel() override { stream_->TryCancel(); } private: - std::unique_lock TakeGuard() { - return read_mutex_ ? 
std::unique_lock(*read_mutex_) - : std::unique_lock(); - } - Status OverrideWithServerError(Status&& st) { if (st.ok()) { return std::move(st); @@ -611,183 +222,49 @@ class GrpcStreamReader : public FlightStreamReader { return stream_->Finish(std::move(st)); } - friend class GrpcIpcMessageReader; - std::shared_ptr rpc_; - std::shared_ptr memory_manager_; - // Guard reads with a lock to prevent Finish()/Close() from being - // called on the writer while the reader has a pending - // read. Nullable, as DoGet() doesn't need this. - std::shared_ptr read_mutex_; + std::shared_ptr stream_; ipc::IpcReadOptions options_; StopToken stop_token_; - std::shared_ptr> stream_; - std::shared_ptr>> - peekable_reader_; + std::shared_ptr peekable_reader_; std::shared_ptr batch_reader_; std::shared_ptr app_metadata_; }; -// The next two classes implement writing to a FlightData stream. -// Similarly to the read side, we want to reuse the implementation of -// RecordBatchWriter. As a result, these two classes are intertwined -// in order to pass application metadata "through" RecordBatchWriter. -// In order to get application-specific metadata to the -// IpcPayloadWriter, DoPutPayloadWriter takes a pointer to -// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on -// write; DoPutPayloadWriter reads that metadata field to determine -// what to write. 
- -template -class DoPutPayloadWriter; - -template -class GrpcStreamWriter : public FlightStreamWriter { - public: - ~GrpcStreamWriter() override = default; - - using GrpcStream = grpc::ClientReaderWriter; - - explicit GrpcStreamWriter( - const FlightDescriptor& descriptor, std::shared_ptr rpc, - int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options, - std::shared_ptr> writer) - : app_metadata_(nullptr), - batch_writer_(nullptr), - writer_(std::move(writer)), - rpc_(std::move(rpc)), - write_size_limit_bytes_(write_size_limit_bytes), - options_(options), - descriptor_(descriptor), - writer_closed_(false) {} - - static Status Open( - const FlightDescriptor& descriptor, std::shared_ptr schema, - const ipc::IpcWriteOptions& options, std::shared_ptr rpc, - int64_t write_size_limit_bytes, - std::shared_ptr> writer, - std::unique_ptr* out); +FlightMetadataReader::~FlightMetadataReader() = default; - Status CheckStarted() { - if (!batch_writer_) { - return Status::Invalid("Writer not initialized. Call Begin() with a schema."); - } - return Status::OK(); - } +class ClientMetadataReader : public FlightMetadataReader { + public: + explicit ClientMetadataReader(std::shared_ptr stream) + : stream_(std::move(stream)) {} - Status Begin(const std::shared_ptr& schema, - const ipc::IpcWriteOptions& options) override { - if (batch_writer_) { - return Status::Invalid("This writer has already been started."); + Status ReadMetadata(std::shared_ptr* out) override { + if (!stream_->ReadPutMetadata(out)) { + return stream_->Finish(Status::OK()); } - std::unique_ptr payload_writer( - new DoPutPayloadWriter( - descriptor_, std::move(rpc_), write_size_limit_bytes_, writer_, this)); - // XXX: this does not actually write the message to the stream. - // See Close(). 
- ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter( - std::move(payload_writer), schema, options)); return Status::OK(); } - Status Begin(const std::shared_ptr& schema) override { - return Begin(schema, options_); - } - - Status WriteRecordBatch(const RecordBatch& batch) override { - RETURN_NOT_OK(CheckStarted()); - return WriteWithMetadata(batch, nullptr); - } - - Status WriteMetadata(std::shared_ptr app_metadata) override { - FlightPayload payload{}; - payload.app_metadata = app_metadata; - auto status = internal::WritePayload(payload, writer_->stream().get()); - if (status.IsIOError()) { - return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, - "Could not write metadata to stream")); - } - return status; - } - - Status WriteWithMetadata(const RecordBatch& batch, - std::shared_ptr app_metadata) override { - RETURN_NOT_OK(CheckStarted()); - app_metadata_ = app_metadata; - return batch_writer_->WriteRecordBatch(batch); - } - - Status DoneWriting() override { - // Do not CheckStarted - DoneWriting applies to data and metadata - if (batch_writer_) { - // Close the writer if we have one; this will force it to flush any - // remaining data, before we close the write side of the stream. - writer_closed_ = true; - Status st = batch_writer_->Close(); - if (!st.ok()) { - return writer_->Finish(std::move(st)); - } - } - return writer_->DoneWriting(); - } - - Status Close() override { - // Do not CheckStarted - Close applies to data and metadata - if (batch_writer_ && !writer_closed_) { - // This is important! Close() calls - // IpcPayloadWriter::CheckStarted() which will force the initial - // schema message to be written to the stream. This is required - // to unstick the server, else the client and the server end up - // waiting for each other. This happens if the client never - // wrote anything before calling Close(). 
- writer_closed_ = true; - return writer_->Finish(batch_writer_->Close()); - } - return writer_->Finish(Status::OK()); - } - - ipc::WriteStats stats() const override { - ARROW_CHECK_NE(batch_writer_, nullptr); - return batch_writer_->stats(); - } - private: - friend class DoPutPayloadWriter; - std::shared_ptr app_metadata_; - std::unique_ptr batch_writer_; - std::shared_ptr> writer_; - - // Fields used to lazy-initialize the IpcPayloadWriter. They're - // invalid once Begin() is called. - std::shared_ptr rpc_; - int64_t write_size_limit_bytes_; - ipc::IpcWriteOptions options_; - FlightDescriptor descriptor_; - bool writer_closed_; + std::shared_ptr stream_; }; -/// A IpcPayloadWriter implementation that writes to a gRPC stream of -/// FlightData messages. -template -class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { +/// \brief An IpcPayloadWriter for any ClientDataStream. +/// +/// To support app_metadata and reuse the existing IPC infrastructure, +/// this takes a pointer to a buffer to be combined with the IPC +/// payload when writing a Flight payload. 
+class ClientPutPayloadWriter : public ipc::internal::IpcPayloadWriter { public: - using GrpcStream = grpc::ClientReaderWriter; - - DoPutPayloadWriter( - const FlightDescriptor& descriptor, std::shared_ptr rpc, - int64_t write_size_limit_bytes, - std::shared_ptr> writer, - GrpcStreamWriter* stream_writer) - : descriptor_(descriptor), - rpc_(rpc), + ClientPutPayloadWriter(std::shared_ptr stream, + FlightDescriptor descriptor, int64_t write_size_limit_bytes, + std::shared_ptr* app_metadata) + : descriptor_(std::move(descriptor)), write_size_limit_bytes_(write_size_limit_bytes), - writer_(std::move(writer)), - first_payload_(true), - stream_writer_(stream_writer) {} - - ~DoPutPayloadWriter() override = default; + stream_(std::move(stream)), + app_metadata_(app_metadata), + first_payload_(true) {} Status Start() override { return Status::OK(); } - Status WritePayload(const ipc::IpcPayload& ipc_payload) override { FlightPayload payload; payload.ipc_message = ipc_payload; @@ -800,9 +277,8 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { // Write the descriptor to begin with RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor)); first_payload_ = false; - } else if (ipc_payload.type == ipc::MessageType::RECORD_BATCH && - stream_writer_->app_metadata_) { - payload.app_metadata = std::move(stream_writer_->app_metadata_); + } else if (ipc_payload.type == ipc::MessageType::RECORD_BATCH && *app_metadata_) { + payload.app_metadata = std::move(*app_metadata_); } if (write_size_limit_bytes_ > 0) { @@ -821,501 +297,168 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { std::make_shared(write_size_limit_bytes_, size)); } } - - auto status = internal::WritePayload(payload, writer_->stream().get()); - if (status.IsIOError()) { - return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, - "Could not write record batch to stream")); + ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload)); + if (!success) { + 
return MakeFlightError( + FlightStatusCode::Internal, + "Could not write record batch to stream (server disconnect?)"); } - return status; + return Status::OK(); } - Status Close() override { - // Closing is handled one layer up in GrpcStreamWriter::Close + // Closing is handled one layer up in ClientStreamWriter::Close return Status::OK(); } - protected: + private: const FlightDescriptor descriptor_; - std::shared_ptr rpc_; - int64_t write_size_limit_bytes_; - std::shared_ptr> writer_; + const int64_t write_size_limit_bytes_; + std::shared_ptr stream_; + std::shared_ptr* app_metadata_; bool first_payload_; - GrpcStreamWriter* stream_writer_; }; -template -Status GrpcStreamWriter::Open( - const FlightDescriptor& descriptor, - std::shared_ptr schema, // this schema is nullable - const ipc::IpcWriteOptions& options, std::shared_ptr rpc, - int64_t write_size_limit_bytes, - std::shared_ptr> writer, - std::unique_ptr* out) { - std::unique_ptr> instance( - new GrpcStreamWriter( - descriptor, std::move(rpc), write_size_limit_bytes, options, writer)); - if (schema) { - // The schema was provided (DoPut). Eagerly write the schema and - // descriptor together as the first message. - RETURN_NOT_OK(instance->Begin(schema, options)); - } else { - // The schema was not provided (DoExchange). Eagerly write just - // the descriptor as the first message. Note that if the client - // calls Begin() to send data, we'll send a redundant descriptor. 
- FlightPayload payload{}; - RETURN_NOT_OK(internal::ToPayload(descriptor, &payload.descriptor)); - auto status = internal::WritePayload(payload, instance->writer_->stream().get()); - if (status.IsIOError()) { - return writer->Finish(MakeFlightError(FlightStatusCode::Internal, - "Could not write descriptor to stream")); - } - RETURN_NOT_OK(status); - } - *out = std::move(instance); - return Status::OK(); -} - -FlightMetadataReader::~FlightMetadataReader() = default; - -class GrpcMetadataReader : public FlightMetadataReader { +class ClientStreamWriter : public FlightStreamWriter { public: - explicit GrpcMetadataReader( - std::shared_ptr> reader, - std::shared_ptr read_mutex) - : reader_(reader), read_mutex_(read_mutex) {} + explicit ClientStreamWriter(std::shared_ptr stream, + const ipc::IpcWriteOptions& options, + int64_t write_size_limit_bytes, FlightDescriptor descriptor) + : stream_(std::move(stream)), + batch_writer_(nullptr), + app_metadata_(nullptr), + writer_closed_(false), + closed_(false), + write_options_(options), + write_size_limit_bytes_(write_size_limit_bytes), + descriptor_(std::move(descriptor)) {} - Status ReadMetadata(std::shared_ptr* out) override { - std::lock_guard guard(*read_mutex_); - pb::PutResult message; - if (reader_->Read(&message)) { - *out = Buffer::FromString(std::move(*message.mutable_app_metadata())); - } else { - // Stream finished - *out = nullptr; + ~ClientStreamWriter() { + if (closed_) return; + // Implicitly Close() on destruction, though it's best if the + // application closes explicitly + auto status = Close(); + if (!status.ok()) { + ARROW_LOG(WARNING) << "Close() failed: " << status.ToString(); } - return Status::OK(); } - private: - std::shared_ptr> reader_; - std::shared_ptr read_mutex_; -}; - -namespace { -// Dummy self-signed certificate to be used because TlsCredentials -// requires root CA certs, even if you are skipping server -// verification. 
-#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) -constexpr char kDummyRootCert[] = - "-----BEGIN CERTIFICATE-----\n" - "MIICwzCCAaugAwIBAgIJAM12DOkcaqrhMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV\n" - "BAMTCWxvY2FsaG9zdDAeFw0yMDEwMDcwODIyNDFaFw0zMDEwMDUwODIyNDFaMBQx\n" - "EjAQBgNVBAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC\n" - "ggEBALjJ8KPEpF0P4GjMPrJhjIBHUL0AX9E4oWdgJRCSFkPKKEWzQabTQBikMOhI\n" - "W4VvBMaHEBuECE5OEyrzDRiAO354I4F4JbBfxMOY8NIW0uWD6THWm2KkCzoZRIPW\n" - "yZL6dN+mK6cEH+YvbNuy5ZQGNjGG43tyiXdOCAc4AI9POeTtjdMpbbpR2VY4Ad/E\n" - "oTEiS3gNnN7WIAdgMhCJxjzvPwKszV3f7pwuTHzFMsuHLKr6JeaVUYfbi4DxxC8Z\n" - "k6PF6dLlLf3ngTSLBJyaXP1BhKMvz0TaMK3F0y2OGwHM9J8np2zWjTlNVEzffQZx\n" - "SWMOQManlJGs60xYx9KCPJMZZsMCAwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxo\n" - "b3N0MA0GCSqGSIb3DQEBBQUAA4IBAQC0LrmbcNKgO+D50d/wOc+vhi9K04EZh8bg\n" - "WYAK1kLOT4eShbzqWGV/1EggY4muQ6ypSELCLuSsg88kVtFQIeRilA6bHFqQSj6t\n" - "sqgh2cWsMwyllCtmX6Maf3CLb2ZdoJlqUwdiBdrbIbuyeAZj3QweCtLKGSQzGDyI\n" - "KH7G8nC5d0IoRPiCMB6RnMMKsrhviuCdWbAFHop7Ff36JaOJ8iRa2sSf2OXE8j/5\n" - "obCXCUvYHf4Zw27JcM2AnnQI9VJLnYxis83TysC5s2Z7t0OYNS9kFmtXQbUNlmpS\n" - "doQ/Eu47vWX7S0TXeGziGtbAOKxbHE0BGGPDOAB/jGW/JVbeTiXY\n" - "-----END CERTIFICATE-----\n"; -#endif -} // namespace -class FlightClient::FlightClientImpl { - public: - Status Connect(const Location& location, const FlightClientOptions& options) { - const std::string& scheme = location.scheme(); - - std::stringstream grpc_uri; - std::shared_ptr creds; - if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { - grpc_uri << arrow::internal::UriEncodeHost(location.uri_->host()) << ':' - << location.uri_->port_text(); - - if (scheme == kSchemeGrpcTls) { - if (options.disable_server_verification) { -#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) - namespace ge = GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS; - -#if defined(GRPC_USE_CERTIFICATE_VERIFIER) - // gRPC >= 1.43 - class NoOpCertificateVerifier : public 
ge::ExternalCertificateVerifier { - public: - bool Verify(ge::TlsCustomVerificationCheckRequest*, - std::function, - grpc::Status* sync_status) override { - *sync_status = grpc::Status::OK; - return true; // Check done synchronously - } - void Cancel(ge::TlsCustomVerificationCheckRequest*) override {} - }; - auto cert_verifier = - ge::ExternalCertificateVerifier::Create(); - -#else // defined(GRPC_USE_CERTIFICATE_VERIFIER) - // gRPC < 1.43 - // A callback to supply to TlsCredentialsOptions that accepts any server - // arguments. - struct NoOpTlsAuthorizationCheck - : public ge::TlsServerAuthorizationCheckInterface { - int Schedule(ge::TlsServerAuthorizationCheckArg* arg) override { - arg->set_success(1); - arg->set_status(GRPC_STATUS_OK); - return 0; - } - }; - auto server_authorization_check = std::make_shared(); - noop_auth_check_ = std::make_shared( - server_authorization_check); -#endif // defined(GRPC_USE_CERTIFICATE_VERIFIER) - -#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) - auto certificate_provider = - std::make_shared( - kDummyRootCert); -#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) - grpc::experimental::TlsChannelCredentialsOptions tls_options( - certificate_provider); -#else // defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) - // While gRPC >= 1.36 does not require a root cert (it has a default) - // in practice the path it hardcodes is broken. See grpc/grpc#21655. 
- grpc::experimental::TlsChannelCredentialsOptions tls_options; - tls_options.set_certificate_provider(certificate_provider); -#endif // defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) - tls_options.watch_root_certs(); - tls_options.set_root_cert_name("dummy"); -#if defined(GRPC_USE_CERTIFICATE_VERIFIER) - tls_options.set_certificate_verifier(std::move(cert_verifier)); - tls_options.set_check_call_host(false); - tls_options.set_verify_server_certs(false); -#else // defined(GRPC_USE_CERTIFICATE_VERIFIER) - tls_options.set_server_verification_option( - grpc_tls_server_verification_option::GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION); - tls_options.set_server_authorization_check_config(noop_auth_check_); -#endif // defined(GRPC_USE_CERTIFICATE_VERIFIER) -#elif defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) - // continues defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) - auto materials_config = std::make_shared(); - materials_config->set_pem_root_certs(kDummyRootCert); - ge::TlsCredentialsOptions tls_options( - GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, - GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION, materials_config, - std::shared_ptr(), noop_auth_check_); -#endif // defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) - creds = ge::TlsCredentials(tls_options); -#else // defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) - return Status::NotImplemented( - "Using encryption with server verification disabled is unsupported. 
" - "Please use a release of Arrow Flight built with gRPC 1.27 or higher."); -#endif // defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) - } else { - grpc::SslCredentialsOptions ssl_options; - if (!options.tls_root_certs.empty()) { - ssl_options.pem_root_certs = options.tls_root_certs; - } - if (!options.cert_chain.empty()) { - ssl_options.pem_cert_chain = options.cert_chain; - } - if (!options.private_key.empty()) { - ssl_options.pem_private_key = options.private_key; - } - creds = grpc::SslCredentials(ssl_options); - } - } else { - creds = grpc::InsecureChannelCredentials(); - } - } else if (scheme == kSchemeGrpcUnix) { - grpc_uri << "unix://" << location.uri_->path(); - creds = grpc::InsecureChannelCredentials(); - } else { - return Status::NotImplemented("Flight scheme " + scheme + " is not supported."); + Status Begin(const std::shared_ptr& schema, + const ipc::IpcWriteOptions& options) override { + if (batch_writer_) { + return Status::Invalid("This writer has already been started."); } + std::unique_ptr payload_writer( + new ClientPutPayloadWriter(stream_, std::move(descriptor_), + write_size_limit_bytes_, &app_metadata_)); + // XXX: this does not actually write the message to the stream. + // See Close(). 
+ ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter( + std::move(payload_writer), schema, options)); + return Status::OK(); + } - grpc::ChannelArguments args; - // We can't set the same config value twice, so for values where - // we want to set defaults, keep them in a map and update them; - // then update them all at once - std::unordered_map default_args; - // Try to reconnect quickly at first, in case the server is still starting up - default_args[GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS] = 100; - // Receive messages of any size - default_args[GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH] = -1; - // Setting this arg enables each client to open it's own TCP connection to server, - // not sharing one single connection, which becomes bottleneck under high load. - default_args[GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL] = 1; - - if (options.override_hostname != "") { - args.SetSslTargetNameOverride(options.override_hostname); - } + Status Begin(const std::shared_ptr& schema) override { + return Begin(schema, write_options_); + } - // Allow setting generic gRPC options. 
- for (const auto& arg : options.generic_options) { - if (util::holds_alternative(arg.second)) { - default_args[arg.first] = util::get(arg.second); - } else if (util::holds_alternative(arg.second)) { - args.SetString(arg.first, util::get(arg.second)); - } - // Otherwise unimplemented - } - for (const auto& pair : default_args) { - args.SetInt(pair.first, pair.second); + // Overload used by FlightClient::DoExchange + Status Begin() { + FlightPayload payload; + RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor)); + ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload)); + if (!success) { + return MakeFlightError( + FlightStatusCode::Internal, + "Could not write record batch to stream (server disconnect?)"); } - - std::vector> - interceptors; - interceptors.emplace_back( - new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); - - stub_ = pb::FlightService::NewStub( - grpc::experimental::CreateCustomChannelWithInterceptors( - grpc_uri.str(), creds, args, std::move(interceptors))); - - write_size_limit_bytes_ = options.write_size_limit_bytes; return Status::OK(); } - Status Authenticate(const FlightCallOptions& options, - std::unique_ptr auth_handler) { - auth_handler_ = std::move(auth_handler); - ClientRpc rpc(options); - std::shared_ptr> - stream = stub_->Handshake(&rpc.context); - GrpcClientAuthSender outgoing{stream}; - GrpcClientAuthReader incoming{stream}; - RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming)); - // Explicitly close our side of the connection - bool finished_writes = stream->WritesDone(); - RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); - if (!finished_writes) { - return MakeFlightError(FlightStatusCode::Internal, - "Could not finish writing before closing"); - } - return Status::OK(); + Status WriteRecordBatch(const RecordBatch& batch) override { + RETURN_NOT_OK(CheckStarted()); + return WriteWithMetadata(batch, nullptr); } - arrow::Result> AuthenticateBasicToken( 
- const FlightCallOptions& options, const std::string& username, - const std::string& password) { - // Add basic auth headers to outgoing headers. - ClientRpc rpc(options); - internal::AddBasicAuthHeaders(&rpc.context, username, password); - - std::shared_ptr> - stream = stub_->Handshake(&rpc.context); - GrpcClientAuthSender outgoing{stream}; - GrpcClientAuthReader incoming{stream}; - - // Explicitly close our side of the connection. - bool finished_writes = stream->WritesDone(); - RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); - if (!finished_writes) { + Status WriteMetadata(std::shared_ptr app_metadata) override { + FlightPayload payload; + payload.app_metadata = app_metadata; + ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload)); + if (!success) { return MakeFlightError(FlightStatusCode::Internal, - "Could not finish writing before closing"); + "Could not write metadata to stream (server disconnect?)"); } - - // Grab bearer token from incoming headers. 
- return internal::GetBearerTokenHeader(rpc.context); + return Status::OK(); } - Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, - std::unique_ptr* listing) { - pb::Criteria pb_criteria; - RETURN_NOT_OK(internal::ToProto(criteria, &pb_criteria)); - - ClientRpc rpc(options); - RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); - std::unique_ptr> stream( - stub_->ListFlights(&rpc.context, pb_criteria)); - - std::vector flights; - - pb::FlightInfo pb_info; - while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) { - FlightInfo::Data info_data; - RETURN_NOT_OK(internal::FromProto(pb_info, &info_data)); - flights.emplace_back(std::move(info_data)); - } - if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); - RETURN_NOT_OK(options.stop_token.Poll()); - listing->reset(new SimpleFlightListing(std::move(flights))); - return internal::FromGrpcStatus(stream->Finish(), &rpc.context); + Status WriteWithMetadata(const RecordBatch& batch, + std::shared_ptr app_metadata) override { + RETURN_NOT_OK(CheckStarted()); + app_metadata_ = app_metadata; + return batch_writer_->WriteRecordBatch(batch); } - Status DoAction(const FlightCallOptions& options, const Action& action, - std::unique_ptr* results) { - pb::Action pb_action; - RETURN_NOT_OK(internal::ToProto(action, &pb_action)); - - ClientRpc rpc(options); - RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); - std::unique_ptr> stream( - stub_->DoAction(&rpc.context, pb_action)); - - pb::Result pb_result; - - std::vector materialized_results; - while (!options.stop_token.IsStopRequested() && stream->Read(&pb_result)) { - Result result; - RETURN_NOT_OK(internal::FromProto(pb_result, &result)); - materialized_results.emplace_back(std::move(result)); + Status DoneWriting() override { + // Do not CheckStarted - DoneWriting applies to data and metadata + if (batch_writer_) { + // Close the writer if we have one; this will force it to flush any + // remaining data, before we close 
the write side of the stream. + writer_closed_ = true; + Status st = batch_writer_->Close(); + if (!st.ok()) { + return stream_->Finish(std::move(st)); + } } - if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); - RETURN_NOT_OK(options.stop_token.Poll()); - - *results = std::unique_ptr( - new SimpleResultStream(std::move(materialized_results))); - return internal::FromGrpcStatus(stream->Finish(), &rpc.context); + return stream_->WritesDone(); } - Status ListActions(const FlightCallOptions& options, std::vector* types) { - pb::Empty empty; - - ClientRpc rpc(options); - RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); - std::unique_ptr> stream( - stub_->ListActions(&rpc.context, empty)); - - pb::ActionType pb_type; - ActionType type; - while (!options.stop_token.IsStopRequested() && stream->Read(&pb_type)) { - RETURN_NOT_OK(internal::FromProto(pb_type, &type)); - types->emplace_back(std::move(type)); + Status Close() override { + // Do not CheckStarted - Close applies to data and metadata + if (!closed_) { + closed_ = true; + if (batch_writer_ && !writer_closed_) { + // This is important! Close() calls + // IpcPayloadWriter::CheckStarted() which will force the initial + // schema message to be written to the stream. This is required + // to unstick the server, else the client and the server end up + // waiting for each other. This happens if the client never + // wrote anything before calling Close(). 
+ writer_closed_ = true; + final_status_ = stream_->Finish(batch_writer_->Close()); + } else { + final_status_ = stream_->Finish(Status::OK()); + } } - if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); - RETURN_NOT_OK(options.stop_token.Poll()); - return internal::FromGrpcStatus(stream->Finish(), &rpc.context); + return final_status_; } - Status GetFlightInfo(const FlightCallOptions& options, - const FlightDescriptor& descriptor, - std::unique_ptr* info) { - pb::FlightDescriptor pb_descriptor; - pb::FlightInfo pb_response; - - RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); - - ClientRpc rpc(options); - RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); - Status s = internal::FromGrpcStatus( - stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context); - RETURN_NOT_OK(s); - - FlightInfo::Data info_data; - RETURN_NOT_OK(internal::FromProto(pb_response, &info_data)); - info->reset(new FlightInfo(std::move(info_data))); - return Status::OK(); + ipc::WriteStats stats() const override { + ARROW_CHECK_NE(batch_writer_, nullptr); + return batch_writer_->stats(); } - Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, - std::unique_ptr* schema_result) { - pb::FlightDescriptor pb_descriptor; - pb::SchemaResult pb_response; - - RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); - - ClientRpc rpc(options); - RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); - Status s = internal::FromGrpcStatus( - stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), &rpc.context); - RETURN_NOT_OK(s); - - std::string str; - RETURN_NOT_OK(internal::FromProto(pb_response, &str)); - schema_result->reset(new SchemaResult(str)); + private: + Status CheckStarted() { + if (!batch_writer_) { + return Status::Invalid("Writer not initialized. 
Call Begin() with a schema."); + } return Status::OK(); } - Status DoGet(const FlightCallOptions& options, const Ticket& ticket, - std::unique_ptr* out) { - using StreamReader = GrpcStreamReader>; - pb::Ticket pb_ticket; - internal::ToProto(ticket, &pb_ticket); - - auto rpc = std::make_shared(options); - RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr> stream = - stub_->DoGet(&rpc->context, pb_ticket); - auto finishable_stream = std::make_shared< - FinishableStream, internal::FlightData>>( - rpc, stream); - *out = std::unique_ptr( - new StreamReader(rpc, options.memory_manager, nullptr, options.read_options, - options.stop_token, finishable_stream)); - // Eagerly read the schema - return static_cast(out->get())->EnsureDataStarted(); - } - - Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor, - const std::shared_ptr& schema, - std::unique_ptr* out, - std::unique_ptr* reader) { - using GrpcStream = grpc::ClientReaderWriter; - using StreamWriter = GrpcStreamWriter; - - auto rpc = std::make_shared(options); - RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr stream = stub_->DoPut(&rpc->context); - // The writer drains the reader on close to avoid hanging inside - // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. 
- std::shared_ptr read_mutex = std::make_shared(); - auto finishable_stream = - std::make_shared>( - rpc, read_mutex, stream); - *reader = - std::unique_ptr(new GrpcMetadataReader(stream, read_mutex)); - return StreamWriter::Open(descriptor, schema, options.write_options, rpc, - write_size_limit_bytes_, finishable_stream, out); - } - - Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor, - std::unique_ptr* writer, - std::unique_ptr* reader) { - using GrpcStream = grpc::ClientReaderWriter; - using StreamReader = GrpcStreamReader; - using StreamWriter = GrpcStreamWriter; - - auto rpc = std::make_shared(options); - RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::shared_ptr> stream = - stub_->DoExchange(&rpc->context); - // The writer drains the reader on close to avoid hanging inside - // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. - std::shared_ptr read_mutex = std::make_shared(); - auto finishable_stream = - std::make_shared>( - rpc, read_mutex, stream); - *reader = std::unique_ptr( - new StreamReader(rpc, options.memory_manager, read_mutex, options.read_options, - options.stop_token, finishable_stream)); - // Do not eagerly read the schema. There may be metadata messages - // before any data is sent, or data may not be sent at all. 
- return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc, - write_size_limit_bytes_, finishable_stream, writer); - } + std::shared_ptr stream_; + std::unique_ptr batch_writer_; + std::shared_ptr app_metadata_; + bool writer_closed_; + bool closed_; + // Close() is expected to be idempotent + Status final_status_; - private: - std::unique_ptr stub_; - std::shared_ptr auth_handler_; -#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) && \ - !defined(GRPC_USE_CERTIFICATE_VERIFIER) - // Scope the TlsServerAuthorizationCheckConfig to be at the class instance level, since - // it gets created during Connect() and needs to persist to DoAction() calls. gRPC does - // not correctly increase the reference count of this object: - // https://github.com/grpc/grpc/issues/22287 - std::shared_ptr< - GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> - noop_auth_check_; -#endif + // Temporary state to construct the IPC payload writer + ipc::IpcWriteOptions write_options_; int64_t write_size_limit_bytes_; + FlightDescriptor descriptor_; }; -FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); } +FlightClient::FlightClient() : closed_(false), write_size_limit_bytes_(0) {} FlightClient::~FlightClient() { auto st = Close(); @@ -1332,47 +475,53 @@ Status FlightClient::Connect(const Location& location, Status FlightClient::Connect(const Location& location, const FlightClientOptions& options, std::unique_ptr* client) { + flight::transport::grpc::InitializeFlightGrpcClient(); + client->reset(new FlightClient); - return (*client)->impl_->Connect(location, options); + (*client)->write_size_limit_bytes_ = options.write_size_limit_bytes; + const auto scheme = location.scheme(); + ARROW_ASSIGN_OR_RAISE((*client)->transport_, + internal::GetDefaultTransportRegistry()->MakeClient(scheme)); + return (*client)->transport_->Init(options, location, *location.uri_); } Status FlightClient::Authenticate(const FlightCallOptions& options, 
std::unique_ptr auth_handler) { RETURN_NOT_OK(CheckOpen()); - return impl_->Authenticate(options, std::move(auth_handler)); + return transport_->Authenticate(options, std::move(auth_handler)); } arrow::Result> FlightClient::AuthenticateBasicToken( const FlightCallOptions& options, const std::string& username, const std::string& password) { RETURN_NOT_OK(CheckOpen()); - return impl_->AuthenticateBasicToken(options, username, password); + return transport_->AuthenticateBasicToken(options, username, password); } Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, std::unique_ptr* results) { RETURN_NOT_OK(CheckOpen()); - return impl_->DoAction(options, action, results); + return transport_->DoAction(options, action, results); } Status FlightClient::ListActions(const FlightCallOptions& options, std::vector* actions) { RETURN_NOT_OK(CheckOpen()); - return impl_->ListActions(options, actions); + return transport_->ListActions(options, actions); } Status FlightClient::GetFlightInfo(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* info) { RETURN_NOT_OK(CheckOpen()); - return impl_->GetFlightInfo(options, descriptor, info); + return transport_->GetFlightInfo(options, descriptor, info); } Status FlightClient::GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, std::unique_ptr* schema_result) { RETURN_NOT_OK(CheckOpen()); - return impl_->GetSchema(options, descriptor, schema_result); + return transport_->GetSchema(options, descriptor, schema_result); } Status FlightClient::ListFlights(std::unique_ptr* listing) { @@ -1384,13 +533,18 @@ Status FlightClient::ListFlights(const FlightCallOptions& options, const Criteria& criteria, std::unique_ptr* listing) { RETURN_NOT_OK(CheckOpen()); - return impl_->ListFlights(options, criteria, listing); + return transport_->ListFlights(options, criteria, listing); } Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& 
ticket, std::unique_ptr* stream) { RETURN_NOT_OK(CheckOpen()); - return impl_->DoGet(options, ticket, stream); + std::unique_ptr remote_stream; + RETURN_NOT_OK(transport_->DoGet(options, ticket, &remote_stream)); + *stream = std::unique_ptr(new ClientStreamReader( + std::move(remote_stream), options.read_options, options.stop_token)); + // Eagerly read the schema + return static_cast(stream->get())->EnsureDataStarted(); } Status FlightClient::DoPut(const FlightCallOptions& options, @@ -1399,7 +553,16 @@ Status FlightClient::DoPut(const FlightCallOptions& options, std::unique_ptr* stream, std::unique_ptr* reader) { RETURN_NOT_OK(CheckOpen()); - return impl_->DoPut(options, descriptor, schema, stream, reader); + std::unique_ptr remote_stream; + RETURN_NOT_OK(transport_->DoPut(options, &remote_stream)); + std::shared_ptr shared_stream = std::move(remote_stream); + *reader = + std::unique_ptr(new ClientMetadataReader(shared_stream)); + *stream = std::unique_ptr( + new ClientStreamWriter(std::move(shared_stream), options.write_options, + write_size_limit_bytes_, descriptor)); + RETURN_NOT_OK((*stream)->Begin(schema, options.write_options)); + return Status::OK(); } Status FlightClient::DoExchange(const FlightCallOptions& options, @@ -1407,18 +570,30 @@ Status FlightClient::DoExchange(const FlightCallOptions& options, std::unique_ptr* writer, std::unique_ptr* reader) { RETURN_NOT_OK(CheckOpen()); - return impl_->DoExchange(options, descriptor, writer, reader); + std::unique_ptr remote_stream; + RETURN_NOT_OK(transport_->DoExchange(options, &remote_stream)); + std::shared_ptr shared_stream = std::move(remote_stream); + *reader = std::unique_ptr( + new ClientStreamReader(shared_stream, options.read_options, options.stop_token)); + auto stream_writer = std::unique_ptr( + new ClientStreamWriter(std::move(shared_stream), options.write_options, + write_size_limit_bytes_, descriptor)); + RETURN_NOT_OK(stream_writer->Begin()); + *writer = std::move(stream_writer); + return 
Status::OK(); } Status FlightClient::Close() { - // gRPC doesn't offer an explicit shutdown - impl_.reset(nullptr); - // TODO(ARROW-15473): if we track ongoing RPCs, we can cancel them first + if (!closed_) { + closed_ = true; + RETURN_NOT_OK(transport_->Close()); + transport_.reset(nullptr); + } return Status::OK(); } Status FlightClient::CheckOpen() const { - if (!impl_) { + if (closed_) { return Status::Invalid("FlightClient is closed"); } return Status::OK(); diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 15c67051931..c5ed60a6c42 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -/// \brief Implementation of Flight RPC client using gRPC. API should be -// considered experimental for now +/// \brief Implementation of Flight RPC client. API should be +/// considered experimental for now #pragma once @@ -34,6 +34,7 @@ #include "arrow/util/cancel.h" #include "arrow/util/variant.h" +#include "arrow/flight/type_fwd.h" #include "arrow/flight/types.h" // IWYU pragma: keep #include "arrow/flight/visibility.h" @@ -44,10 +45,6 @@ class Schema; namespace flight { -class ClientAuthHandler; -class ClientMiddleware; -class ClientMiddlewareFactory; - /// \brief A duration type for Flight call timeouts. typedef std::chrono::duration TimeoutDuration; @@ -176,7 +173,7 @@ class ARROW_FLIGHT_EXPORT FlightMetadataReader { virtual Status ReadMetadata(std::shared_ptr* out) = 0; }; -/// \brief Client class for Arrow Flight RPC services (gRPC-based). +/// \brief Client class for Arrow Flight RPC services. 
/// API experimental for now class ARROW_FLIGHT_EXPORT FlightClient { public: @@ -336,8 +333,9 @@ class ARROW_FLIGHT_EXPORT FlightClient { private: FlightClient(); Status CheckOpen() const; - class FlightClientImpl; - std::unique_ptr impl_; + std::unique_ptr transport_; + bool closed_; + int64_t write_size_limit_bytes_; }; } // namespace flight diff --git a/cpp/src/arrow/flight/client_cookie_middleware.cc b/cpp/src/arrow/flight/client_cookie_middleware.cc index 145705e9715..063c8c7f585 100644 --- a/cpp/src/arrow/flight/client_cookie_middleware.cc +++ b/cpp/src/arrow/flight/client_cookie_middleware.cc @@ -16,7 +16,7 @@ // under the License. #include "arrow/flight/client_cookie_middleware.h" -#include "arrow/flight/client_header_internal.h" +#include "arrow/flight/cookie_internal.h" #include "arrow/util/value_parsing.h" namespace arrow { diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/cookie_internal.cc similarity index 81% rename from cpp/src/arrow/flight/client_header_internal.cc rename to cpp/src/arrow/flight/cookie_internal.cc index f7dfd54b646..1a15da92676 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/cookie_internal.cc @@ -18,7 +18,7 @@ // Interfaces for defining middleware for Flight clients. Currently // experimental. 
-#include "arrow/flight/client_header_internal.h" +#include "arrow/flight/cookie_internal.h" #include "arrow/flight/client.h" #include "arrow/flight/client_auth.h" #include "arrow/flight/platform.h" @@ -41,9 +41,6 @@ #include #include -const char kAuthHeader[] = "authorization"; -const char kBearerPrefix[] = "Bearer "; -const char kBasicPrefix[] = "Basic "; const char kCookieExpiresFormat[] = "%d %m %Y %H:%M:%S"; namespace arrow { @@ -66,7 +63,7 @@ size_t CaseInsensitiveHash::operator()(const std::string& key) const { return std::hash{}(upper_string); } -Cookie Cookie::parse(const arrow::util::string_view& cookie_header_value) { +Cookie Cookie::Parse(const arrow::util::string_view& cookie_header_value) { // Parse the cookie string. If the cookie has an expiration, record it. // If the cookie has a max-age, calculate the current time + max_age and set that as // the expiration. @@ -256,7 +253,7 @@ void CookieCache::UpdateCachedCookies(const CallHeaders& incoming_headers) { for (auto it = header_values.first; it != header_values.second; ++it) { const util::string_view& value = it->second; - Cookie cookie = Cookie::parse(value); + Cookie cookie = Cookie::Parse(value); // Cache cookies regardless of whether or not they are expired. The server may have // explicitly sent a Set-Cookie to expire a cached cookie. @@ -284,53 +281,6 @@ std::string CookieCache::GetValidCookiesAsString() { return cookie_string; } -/// \brief Add base64 encoded credentials to the outbound headers. -/// -/// \param context Context object to add the headers to. -/// \param username Username to format and encode. -/// \param password Password to format and encode. -void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, - const std::string& password) { - const std::string credentials = username + ":" + password; - context->AddMetadata(kAuthHeader, - kBasicPrefix + arrow::util::base64_encode(credentials)); -} - -/// \brief Get bearer token from inbound headers. 
-/// -/// \param context Incoming ClientContext that contains headers. -/// \return Arrow result with bearer token (empty if no bearer token found). -arrow::Result> GetBearerTokenHeader( - grpc::ClientContext& context) { - // Lambda function to compare characters without case sensitivity. - auto char_compare = [](const char& char1, const char& char2) { - return (::toupper(char1) == ::toupper(char2)); - }; - - // Get the auth token if it exists, this can be in the initial or the trailing metadata. - auto trailing_headers = context.GetServerTrailingMetadata(); - auto initial_headers = context.GetServerInitialMetadata(); - auto bearer_iter = trailing_headers.find(kAuthHeader); - if (bearer_iter == trailing_headers.end()) { - bearer_iter = initial_headers.find(kAuthHeader); - if (bearer_iter == initial_headers.end()) { - return std::make_pair("", ""); - } - } - - // Check if the value of the auth token starts with the bearer prefix and latch it. - std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size()); - if (bearer_val.size() > strlen(kBearerPrefix)) { - if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), - kBearerPrefix, char_compare)) { - return std::make_pair(kAuthHeader, bearer_val); - } - } - - // The server is not required to provide a bearer token. - return std::make_pair("", ""); -} - } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/client_header_internal.h b/cpp/src/arrow/flight/cookie_internal.h similarity index 77% rename from cpp/src/arrow/flight/client_header_internal.h rename to cpp/src/arrow/flight/cookie_internal.h index dd4498e0315..6b3af516bb6 100644 --- a/cpp/src/arrow/flight/client_header_internal.h +++ b/cpp/src/arrow/flight/cookie_internal.h @@ -15,29 +15,20 @@ // specific language governing permissions and limitations // under the License. -// Interfaces for defining middleware for Flight clients. Currently -// experimental. 
+// Utilities for working with HTTP cookies. #pragma once -#include "arrow/flight/client_middleware.h" -#include "arrow/result.h" -#include "arrow/util/optional.h" - -#ifdef GRPCPP_PP_INCLUDE -#include -#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) -#include -#endif -#else -#include -#endif - #include -#include #include #include #include +#include + +#include "arrow/flight/client_middleware.h" +#include "arrow/result.h" +#include "arrow/util/optional.h" +#include "arrow/util/string_view.h" namespace arrow { namespace flight { @@ -63,7 +54,7 @@ class ARROW_FLIGHT_EXPORT Cookie { /// \brief Parse function to parse a cookie header value and return a Cookie object. /// /// \return Cookie object based on cookie header value. - static Cookie parse(const arrow::util::string_view& cookie_header_value); + static Cookie Parse(const arrow::util::string_view& cookie_header_value); /// \brief Parse a cookie header string beginning at the given start_pos and identify /// the name and value of an attribute. @@ -130,22 +121,6 @@ class ARROW_FLIGHT_EXPORT CookieCache { cookies; }; -/// \brief Add basic authentication header key value pair to context. -/// -/// \param context grpc context variable to add header to. -/// \param username username to encode into header. -/// \param password password to to encode into header. -void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, - const std::string& username, - const std::string& password); - -/// \brief Get bearer token from incoming headers. -/// -/// \param context context that contains headers which hold the bearer token. -/// \return Bearer token, parsed from headers, empty if one is not present. 
-arrow::Result> ARROW_FLIGHT_EXPORT -GetBearerTokenHeader(grpc::ClientContext& context); - } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/flight_cuda_test.cc b/cpp/src/arrow/flight/flight_cuda_test.cc deleted file mode 100644 index b812efd4677..00000000000 --- a/cpp/src/arrow/flight/flight_cuda_test.cc +++ /dev/null @@ -1,229 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include - -#include - -#include "arrow/array.h" -#include "arrow/flight/client.h" -#include "arrow/flight/server.h" -#include "arrow/flight/test_util.h" -#include "arrow/gpu/cuda_api.h" -#include "arrow/table.h" -#include "arrow/testing/gtest_util.h" - -namespace arrow { -namespace flight { - -Status CheckBuffersOnDevice(const Array& array, const Device& device) { - if (array.num_fields() != 0) { - return Status::NotImplemented("Nested arrays"); - } - for (const auto& buffer : array.data()->buffers) { - if (!buffer) continue; - if (!buffer->device()->Equals(device)) { - return Status::Invalid("Expected buffer on device: ", device.ToString(), - ". Was allocated on device: ", buffer->device()->ToString()); - } - } - return Status::OK(); -} - -// Copy a record batch to host memory. 
-arrow::Result> CopyBatchToHost(const RecordBatch& batch) { - auto mm = CPUDevice::Instance()->default_memory_manager(); - ArrayVector arrays; - for (const auto& column : batch.columns()) { - std::shared_ptr data = column->data()->Copy(); - if (data->child_data.size() != 0) { - return Status::NotImplemented("Nested arrays"); - } - - for (size_t i = 0; i < data->buffers.size(); i++) { - const auto& buffer = data->buffers[i]; - if (!buffer || buffer->is_cpu()) continue; - ARROW_ASSIGN_OR_RAISE(data->buffers[i], Buffer::Copy(buffer, mm)); - } - arrays.push_back(MakeArray(data)); - } - return RecordBatch::Make(batch.schema(), batch.num_rows(), std::move(arrays)); -} - -class CudaTestServer : public FlightServerBase { - public: - explicit CudaTestServer(std::shared_ptr device) : device_(std::move(device)) {} - - Status DoGet(const ServerCallContext&, const Ticket&, - std::unique_ptr* data_stream) override { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches)); - *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); - return Status::OK(); - } - - Status DoPut(const ServerCallContext&, std::unique_ptr reader, - std::unique_ptr writer) override { - RecordBatchVector batches; - RETURN_NOT_OK(reader->ReadAll(&batches)); - for (const auto& batch : batches) { - for (const auto& column : batch->columns()) { - RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); - } - } - return Status::OK(); - } - - Status DoExchange(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - FlightStreamChunk chunk; - bool begun = false; - while (true) { - RETURN_NOT_OK(reader->Next(&chunk)); - if (!chunk.data) break; - if (!begun) { - begun = true; - RETURN_NOT_OK(writer->Begin(chunk.data->schema())); - } - for (const auto& column : chunk.data->columns()) { - RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); - } - 
RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); - } - return Status::OK(); - } - - private: - std::shared_ptr device_; -}; - -class TestCuda : public ::testing::Test { - public: - void SetUp() { - ASSERT_OK_AND_ASSIGN(manager_, cuda::CudaDeviceManager::Instance()); - ASSERT_OK_AND_ASSIGN(device_, manager_->GetDevice(0)); - ASSERT_OK_AND_ASSIGN(context_, device_->GetContext()); - - ASSERT_OK(MakeServer( - &server_, &client_, - [this](FlightServerOptions* options) { - options->memory_manager = device_->default_memory_manager(); - return Status::OK(); - }, - [](FlightClientOptions* options) { return Status::OK(); }, device_)); - } - void TearDown() { ASSERT_OK(server_->Shutdown()); } - - protected: - cuda::CudaDeviceManager* manager_; - std::shared_ptr device_; - std::shared_ptr context_; - - std::unique_ptr client_; - std::unique_ptr server_; -}; - -TEST_F(TestCuda, DoGet) { - // Check that we can allocate the results of DoGet with a custom - // memory manager. - FlightCallOptions options; - options.memory_manager = device_->default_memory_manager(); - - Ticket ticket{""}; - std::unique_ptr stream; - ASSERT_OK(client_->DoGet(options, ticket, &stream)); - std::shared_ptr
table; - ASSERT_OK(stream->ReadAll(&table)); - - for (const auto& column : table->columns()) { - for (const auto& chunk : column->chunks()) { - ASSERT_OK(CheckBuffersOnDevice(*chunk, *device_)); - } - } -} - -TEST_F(TestCuda, DoPut) { - // Check that we can send a record batch containing references to - // GPU buffers. - RecordBatchVector batches; - ASSERT_OK(ExampleIntBatches(&batches)); - - std::unique_ptr writer; - std::unique_ptr reader; - auto descriptor = FlightDescriptor::Path({""}); - ASSERT_OK(client_->DoPut(descriptor, batches[0]->schema(), &writer, &reader)); - - ipc::DictionaryMemo memo; - for (const auto& batch : batches) { - ASSERT_OK_AND_ASSIGN(auto buffer, cuda::SerializeRecordBatch(*batch, context_.get())); - ASSERT_OK_AND_ASSIGN(auto cuda_batch, - cuda::ReadRecordBatch(batch->schema(), &memo, buffer)); - - for (const auto& column : cuda_batch->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *device_)); - } - - ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); - } - ASSERT_OK(writer->Close()); -} - -TEST_F(TestCuda, DoExchange) { - // Check that we can send a record batch containing references to - // GPU buffers. 
- FlightCallOptions options; - options.memory_manager = device_->default_memory_manager(); - - RecordBatchVector batches; - ASSERT_OK(ExampleIntBatches(&batches)); - - std::unique_ptr writer; - std::unique_ptr reader; - auto descriptor = FlightDescriptor::Path({""}); - ASSERT_OK(client_->DoExchange(options, descriptor, &writer, &reader)); - ASSERT_OK(writer->Begin(batches[0]->schema())); - - ipc::DictionaryMemo write_memo; - ipc::DictionaryMemo read_memo; - for (const auto& batch : batches) { - ASSERT_OK_AND_ASSIGN(auto buffer, cuda::SerializeRecordBatch(*batch, context_.get())); - ASSERT_OK_AND_ASSIGN(auto cuda_batch, - cuda::ReadRecordBatch(batch->schema(), &write_memo, buffer)); - - for (const auto& column : cuda_batch->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *device_)); - } - - ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); - - FlightStreamChunk chunk; - ASSERT_OK(reader->Next(&chunk)); - for (const auto& column : chunk.data->columns()) { - ASSERT_OK(CheckBuffersOnDevice(*column, *device_)); - } - - // Bounce record batch back to host memory - ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data)); - AssertBatchesEqual(*batch, *host_batch); - } - ASSERT_OK(writer->Close()); -} - -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 525e19499c3..e4babc03352 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -22,16 +22,16 @@ #include #include "arrow/flight/client_cookie_middleware.h" -#include "arrow/flight/client_header_internal.h" #include "arrow/flight/client_middleware.h" +#include "arrow/flight/cookie_internal.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/transport/grpc/util_internal.h" #include "arrow/flight/types.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include 
"arrow/util/string.h" -#include "arrow/flight/internal.h" -#include "arrow/flight/test_util.h" - namespace arrow { namespace flight { @@ -156,21 +156,23 @@ TEST(FlightTypes, RoundtripStatus) { ASSERT_NE(nullptr, detail); ASSERT_EQ(FlightStatusCode::Unavailable, detail->code()); - Status status = internal::FromGrpcStatus( - internal::ToGrpcStatus(Status::NotImplemented("Sentinel"))); + Status status = flight::transport::grpc::FromGrpcStatus( + flight::transport::grpc::ToGrpcStatus(Status::NotImplemented("Sentinel"))); ASSERT_TRUE(status.IsNotImplemented()); ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); - status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel"))); + status = flight::transport::grpc::FromGrpcStatus( + flight::transport::grpc::ToGrpcStatus(Status::Invalid("Sentinel"))); ASSERT_TRUE(status.IsInvalid()); ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); - status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel"))); + status = flight::transport::grpc::FromGrpcStatus( + flight::transport::grpc::ToGrpcStatus(Status::KeyError("Sentinel"))); ASSERT_TRUE(status.IsKeyError()); ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); - status = - internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel"))); + status = flight::transport::grpc::FromGrpcStatus( + flight::transport::grpc::ToGrpcStatus(Status::AlreadyExists("Sentinel"))); ASSERT_TRUE(status.IsAlreadyExists()); ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); } @@ -300,18 +302,18 @@ TEST_F(TestCookieMiddleware, Expires) { class TestCookieParsing : public ::testing::Test { public: void VerifyParseCookie(const std::string& cookie_str, bool expired) { - internal::Cookie cookie = internal::Cookie::parse(cookie_str); + internal::Cookie cookie = internal::Cookie::Parse(cookie_str); EXPECT_EQ(expired, cookie.IsExpired()); } void VerifyCookieName(const std::string& 
cookie_str, const std::string& name) { - internal::Cookie cookie = internal::Cookie::parse(cookie_str); + internal::Cookie cookie = internal::Cookie::Parse(cookie_str); EXPECT_EQ(name, cookie.GetName()); } void VerifyCookieString(const std::string& cookie_str, const std::string& cookie_as_string) { - internal::Cookie cookie = internal::Cookie::parse(cookie_str); + internal::Cookie cookie = internal::Cookie::Parse(cookie_str); EXPECT_EQ(cookie_as_string, cookie.AsCookieString()); } diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 72644448f2e..812bd080f18 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -44,8 +44,15 @@ #error "gRPC headers should not be in public API" #endif -#include "arrow/flight/internal.h" -#include "arrow/flight/middleware_internal.h" +#ifdef GRPCPP_PP_INCLUDE +#include +#else +#include +#endif + +// Include before test_util.h (boost), contains Windows fixes +#include "arrow/flight/platform.h" +#include "arrow/flight/serialization_internal.h" #include "arrow/flight/test_definitions.h" #include "arrow/flight/test_util.h" @@ -100,17 +107,16 @@ class GrpcDoPutTest : public DoPutTest { std::string transport() const override { return "grpc"; } }; TEST_F(GrpcDoPutTest, TestInts) { TestInts(); } -TEST_F(GrpcDoPutTest, TestDoPutFloats) { TestDoPutFloats(); } -TEST_F(GrpcDoPutTest, TestDoPutEmptyBatch) { TestDoPutEmptyBatch(); } -TEST_F(GrpcDoPutTest, TestDoPutDicts) { TestDoPutDicts(); } -TEST_F(GrpcDoPutTest, TestDoPutLargeBatch) { TestDoPutLargeBatch(); } -TEST_F(GrpcDoPutTest, TestDoPutSizeLimit) { TestDoPutSizeLimit(); } +TEST_F(GrpcDoPutTest, TestFloats) { TestFloats(); } +TEST_F(GrpcDoPutTest, TestEmptyBatch) { TestEmptyBatch(); } +TEST_F(GrpcDoPutTest, TestDicts) { TestDicts(); } +TEST_F(GrpcDoPutTest, TestLargeBatch) { TestLargeBatch(); } +TEST_F(GrpcDoPutTest, TestSizeLimit) { TestSizeLimit(); } class GrpcAppMetadataTest : public AppMetadataTest { 
protected: std::string transport() const override { return "grpc"; } }; - TEST_F(GrpcAppMetadataTest, TestDoGet) { TestDoGet(); } TEST_F(GrpcAppMetadataTest, TestDoGetDictionaries) { TestDoGetDictionaries(); } TEST_F(GrpcAppMetadataTest, TestDoPut) { TestDoPut(); } @@ -121,7 +127,6 @@ class GrpcIpcOptionsTest : public IpcOptionsTest { protected: std::string transport() const override { return "grpc"; } }; - TEST_F(GrpcIpcOptionsTest, TestDoGetReadOptions) { TestDoGetReadOptions(); } TEST_F(GrpcIpcOptionsTest, TestDoPutWriteOptions) { TestDoPutWriteOptions(); } TEST_F(GrpcIpcOptionsTest, TestDoExchangeClientWriteOptions) { @@ -134,6 +139,14 @@ TEST_F(GrpcIpcOptionsTest, TestDoExchangeServerWriteOptions) { TestDoExchangeServerWriteOptions(); } +class GrpcCudaDataTest : public CudaDataTest { + protected: + std::string transport() const override { return "grpc"; } +}; +TEST_F(GrpcCudaDataTest, TestDoGet) { TestDoGet(); } +TEST_F(GrpcCudaDataTest, TestDoPut) { TestDoPut(); } +TEST_F(GrpcCudaDataTest, TestDoExchange) { TestDoExchange(); } + //------------------------------------------------------------ // Ad-hoc gRPC-specific tests @@ -1606,10 +1619,12 @@ TEST_F(TestCancel, DoExchange) { std::shared_ptr
table; EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"), stream->ReadAll(&table)); + ARROW_UNUSED(writer->Close()); ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream)); EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"), stream->ReadAll(&table, options.stop_token)); + ARROW_UNUSED(writer->Close()); } } // namespace flight diff --git a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt index 3a878d7f305..c7a8c4fe459 100644 --- a/cpp/src/arrow/flight/integration_tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/integration_tests/CMakeLists.txt @@ -23,14 +23,10 @@ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") arrow_flight_testing_static arrow_flight_sql_static ${ARROW_FLIGHT_STATIC_LINK_LIBS} - ${ARROW_TEST_LINK_LIBS}) + ${ARROW_FLIGHT_TEST_LINK_LIBS}) else() - set(ARROW_FLIGHT_TEST_LINK_LIBS - arrow_flight_shared - arrow_flight_testing_shared - arrow_flight_sql_shared - ${ARROW_TEST_LINK_LIBS} - ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}) + set(ARROW_FLIGHT_TEST_LINK_LIBS arrow_flight_shared arrow_flight_testing_shared + arrow_flight_sql_shared ${ARROW_FLIGHT_TEST_LINK_LIBS}) endif() add_executable(flight-test-integration-server test_integration_server.cc diff --git a/cpp/src/arrow/flight/integration_tests/test_integration_server.cc b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc index 92b2241a872..dad76c6914d 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration_server.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration_server.cc @@ -35,7 +35,7 @@ #include "arrow/util/logging.h" #include "arrow/flight/integration_tests/test_integration.h" -#include "arrow/flight/internal.h" +#include "arrow/flight/serialization_internal.h" #include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" #include "arrow/flight/test_util.h" diff --git a/cpp/src/arrow/flight/internal.cc 
b/cpp/src/arrow/flight/internal.cc deleted file mode 100644 index f27de208ac3..00000000000 --- a/cpp/src/arrow/flight/internal.cc +++ /dev/null @@ -1,514 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/flight/internal.h" - -#include -#include -#include -#include -#include -#include - -#include "arrow/flight/platform.h" -#include "arrow/flight/protocol_internal.h" -#include "arrow/flight/types.h" - -#ifdef GRPCPP_PP_INCLUDE -#include -#else -#include -#endif - -#include "arrow/buffer.h" -#include "arrow/io/memory.h" -#include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" -#include "arrow/memory_pool.h" -#include "arrow/status.h" -#include "arrow/util/logging.h" -#include "arrow/util/string_builder.h" - -namespace arrow { -namespace flight { -namespace internal { - -const char* kGrpcAuthHeader = "auth-token-bin"; -const char* kGrpcStatusCodeHeader = "x-arrow-status"; -const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin"; -const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin"; -const char* kBinaryErrorDetailsKey = "grpc-status-details-bin"; - -static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode* code) { - // Bounce through 
std::string to get a proper null-terminated C string - const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str()); - switch (code_int) { - case static_cast(StatusCode::OutOfMemory): - case static_cast(StatusCode::KeyError): - case static_cast(StatusCode::TypeError): - case static_cast(StatusCode::Invalid): - case static_cast(StatusCode::IOError): - case static_cast(StatusCode::CapacityError): - case static_cast(StatusCode::IndexError): - case static_cast(StatusCode::UnknownError): - case static_cast(StatusCode::NotImplemented): - case static_cast(StatusCode::SerializationError): - case static_cast(StatusCode::RError): - case static_cast(StatusCode::CodeGenError): - case static_cast(StatusCode::ExpressionValidationError): - case static_cast(StatusCode::ExecutionError): - case static_cast(StatusCode::AlreadyExists): { - *code = static_cast(code_int); - return Status::OK(); - } - default: - // Code is invalid - return Status::UnknownError("Unknown Arrow status code", code_ref); - } -} - -/// Try to extract a status from gRPC trailers. -/// Return Status::OK if found, an error otherwise. 
-static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status, - std::shared_ptr flightStatusDetail) { - const std::multimap& trailers = - ctx.GetServerTrailingMetadata(); - const auto code_val = trailers.find(kGrpcStatusCodeHeader); - if (code_val == trailers.end()) { - return Status::IOError("Status code header not found"); - } - - const grpc::string_ref code_ref = code_val->second; - StatusCode code = {}; - RETURN_NOT_OK(StatusCodeFromString(code_ref, &code)); - - const auto message_val = trailers.find(kGrpcStatusMessageHeader); - if (message_val == trailers.end()) { - return Status::IOError("Status message header not found"); - } - - const grpc::string_ref message_ref = message_val->second; - std::string message = std::string(message_ref.data(), message_ref.size()); - const auto detail_val = trailers.find(kGrpcStatusDetailHeader); - if (detail_val != trailers.end()) { - const grpc::string_ref detail_ref = detail_val->second; - message += ". Detail: "; - message += std::string(detail_ref.data(), detail_ref.size()); - } - const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey); - if (grpc_detail_val != trailers.end()) { - const grpc::string_ref detail_ref = grpc_detail_val->second; - std::string bin_detail = std::string(detail_ref.data(), detail_ref.size()); - if (!flightStatusDetail) { - flightStatusDetail = - std::make_shared(FlightStatusCode::Internal); - } - flightStatusDetail->set_extra_info(bin_detail); - } - *status = Status(code, message, flightStatusDetail); - return Status::OK(); -} - -/// Convert a gRPC status to an Arrow status, ignoring any -/// implementation-defined headers that encode further detail. 
-static Status FromGrpcCode(const grpc::Status& grpc_status) { - switch (grpc_status.error_code()) { - case grpc::StatusCode::OK: - return Status::OK(); - case grpc::StatusCode::CANCELLED: - return Status::IOError("gRPC cancelled call, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::Cancelled)); - case grpc::StatusCode::UNKNOWN: { - std::stringstream ss; - ss << "Flight RPC failed with message: " << grpc_status.error_message(); - return Status::UnknownError(ss.str()).WithDetail( - std::make_shared(FlightStatusCode::Failed)); - } - case grpc::StatusCode::INVALID_ARGUMENT: - return Status::Invalid("gRPC returned invalid argument error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::DEADLINE_EXCEEDED: - return Status::IOError("gRPC returned deadline exceeded error, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::TimedOut)); - case grpc::StatusCode::NOT_FOUND: - return Status::KeyError("gRPC returned not found error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::ALREADY_EXISTS: - return Status::AlreadyExists("gRPC returned already exists error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::PERMISSION_DENIED: - return Status::IOError("gRPC returned permission denied error, with message: ", - grpc_status.error_message()) - .WithDetail( - std::make_shared(FlightStatusCode::Unauthorized)); - case grpc::StatusCode::RESOURCE_EXHAUSTED: - return Status::Invalid("gRPC returned resource exhausted error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::FAILED_PRECONDITION: - return Status::Invalid("gRPC returned precondition failed error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::ABORTED: - return Status::IOError("gRPC returned aborted error, with message: ", - grpc_status.error_message()) - 
.WithDetail(std::make_shared(FlightStatusCode::Internal)); - case grpc::StatusCode::OUT_OF_RANGE: - return Status::Invalid("gRPC returned out-of-range error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::UNIMPLEMENTED: - return Status::NotImplemented("gRPC returned unimplemented error, with message: ", - grpc_status.error_message()); - case grpc::StatusCode::INTERNAL: - return Status::IOError("gRPC returned internal error, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::Internal)); - case grpc::StatusCode::UNAVAILABLE: - return Status::IOError("gRPC returned unavailable error, with message: ", - grpc_status.error_message()) - .WithDetail( - std::make_shared(FlightStatusCode::Unavailable)); - case grpc::StatusCode::DATA_LOSS: - return Status::IOError("gRPC returned data loss error, with message: ", - grpc_status.error_message()) - .WithDetail(std::make_shared(FlightStatusCode::Internal)); - case grpc::StatusCode::UNAUTHENTICATED: - return Status::IOError("gRPC returned unauthenticated error, with message: ", - grpc_status.error_message()) - .WithDetail( - std::make_shared(FlightStatusCode::Unauthenticated)); - default: - return Status::UnknownError("gRPC failed with error code ", - grpc_status.error_code(), - " and message: ", grpc_status.error_message()); - } -} - -Status FromGrpcStatus(const grpc::Status& grpc_status, grpc::ClientContext* ctx) { - const Status status = FromGrpcCode(grpc_status); - - if (!status.ok() && ctx) { - Status arrow_status; - - if (!FromGrpcContext(*ctx, &arrow_status, FlightStatusDetail::UnwrapStatus(status)) - .ok()) { - // If we fail to decode a more detailed status from the headers, - // proceed normally - return status; - } - - return arrow_status; - } - return status; -} - -/// Convert an Arrow status to a gRPC status. 
-static grpc::Status ToRawGrpcStatus(const Status& arrow_status) { - if (arrow_status.ok()) { - return grpc::Status::OK; - } - - grpc::StatusCode grpc_code = grpc::StatusCode::UNKNOWN; - std::string message = arrow_status.message(); - if (arrow_status.detail()) { - message += ". Detail: "; - message += arrow_status.detail()->ToString(); - } - - std::shared_ptr flight_status = - FlightStatusDetail::UnwrapStatus(arrow_status); - if (flight_status) { - switch (flight_status->code()) { - case FlightStatusCode::Internal: - grpc_code = grpc::StatusCode::INTERNAL; - break; - case FlightStatusCode::TimedOut: - grpc_code = grpc::StatusCode::DEADLINE_EXCEEDED; - break; - case FlightStatusCode::Cancelled: - grpc_code = grpc::StatusCode::CANCELLED; - break; - case FlightStatusCode::Unauthenticated: - grpc_code = grpc::StatusCode::UNAUTHENTICATED; - break; - case FlightStatusCode::Unauthorized: - grpc_code = grpc::StatusCode::PERMISSION_DENIED; - break; - case FlightStatusCode::Unavailable: - grpc_code = grpc::StatusCode::UNAVAILABLE; - break; - default: - break; - } - } else if (arrow_status.IsNotImplemented()) { - grpc_code = grpc::StatusCode::UNIMPLEMENTED; - } else if (arrow_status.IsInvalid()) { - grpc_code = grpc::StatusCode::INVALID_ARGUMENT; - } else if (arrow_status.IsKeyError()) { - grpc_code = grpc::StatusCode::NOT_FOUND; - } else if (arrow_status.IsAlreadyExists()) { - grpc_code = grpc::StatusCode::ALREADY_EXISTS; - } - return grpc::Status(grpc_code, message); -} - -/// Convert an Arrow status to a gRPC status, and add extra headers to -/// the response to encode the original Arrow status. 
-grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx) { - grpc::Status status = ToRawGrpcStatus(arrow_status); - if (!status.ok() && ctx) { - const std::string code = std::to_string(static_cast(arrow_status.code())); - ctx->AddTrailingMetadata(internal::kGrpcStatusCodeHeader, code); - ctx->AddTrailingMetadata(internal::kGrpcStatusMessageHeader, arrow_status.message()); - if (arrow_status.detail()) { - const std::string detail_string = arrow_status.detail()->ToString(); - ctx->AddTrailingMetadata(internal::kGrpcStatusDetailHeader, detail_string); - } - auto fsd = FlightStatusDetail::UnwrapStatus(arrow_status); - if (fsd && !fsd->extra_info().empty()) { - ctx->AddTrailingMetadata(internal::kBinaryErrorDetailsKey, fsd->extra_info()); - } - } - - return status; -} - -// ActionType - -Status FromProto(const pb::ActionType& pb_type, ActionType* type) { - type->type = pb_type.type(); - type->description = pb_type.description(); - return Status::OK(); -} - -Status ToProto(const ActionType& type, pb::ActionType* pb_type) { - pb_type->set_type(type.type); - pb_type->set_description(type.description); - return Status::OK(); -} - -// Action - -Status FromProto(const pb::Action& pb_action, Action* action) { - action->type = pb_action.type(); - action->body = Buffer::FromString(pb_action.body()); - return Status::OK(); -} - -Status ToProto(const Action& action, pb::Action* pb_action) { - pb_action->set_type(action.type); - if (action.body) { - pb_action->set_body(action.body->ToString()); - } - return Status::OK(); -} - -// Result (of an Action) - -Status FromProto(const pb::Result& pb_result, Result* result) { - // ARROW-3250; can avoid copy. 
Can also write custom deserializer if it - // becomes an issue - result->body = Buffer::FromString(pb_result.body()); - return Status::OK(); -} - -Status ToProto(const Result& result, pb::Result* pb_result) { - pb_result->set_body(result.body->ToString()); - return Status::OK(); -} - -// Criteria - -Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria) { - criteria->expression = pb_criteria.expression(); - return Status::OK(); -} -Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria) { - pb_criteria->set_expression(criteria.expression); - return Status::OK(); -} - -// Location - -Status FromProto(const pb::Location& pb_location, Location* location) { - return Location::Parse(pb_location.uri(), location); -} - -void ToProto(const Location& location, pb::Location* pb_location) { - pb_location->set_uri(location.ToString()); -} - -Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth) { - pb_basic_auth->set_username(basic_auth.username); - pb_basic_auth->set_password(basic_auth.password); - return Status::OK(); -} - -// Ticket - -Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket) { - ticket->ticket = pb_ticket.ticket(); - return Status::OK(); -} - -void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { - pb_ticket->set_ticket(ticket.ticket); -} - -// FlightData - -Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, - std::unique_ptr* message) { - RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); - const std::string& header = pb_data.data_header(); - const std::string& body = pb_data.data_body(); - std::shared_ptr header_buf = Buffer::Wrap(header.data(), header.size()); - std::shared_ptr body_buf = Buffer::Wrap(body.data(), body.size()); - if (header_buf == nullptr || body_buf == nullptr) { - return Status::UnknownError("Could not create buffers from protobuf"); - } - return ipc::Message::Open(header_buf, body_buf).Value(message); -} - -// FlightEndpoint 
- -Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint) { - RETURN_NOT_OK(FromProto(pb_endpoint.ticket(), &endpoint->ticket)); - endpoint->locations.resize(pb_endpoint.location_size()); - for (int i = 0; i < pb_endpoint.location_size(); ++i) { - RETURN_NOT_OK(FromProto(pb_endpoint.location(i), &endpoint->locations[i])); - } - return Status::OK(); -} - -void ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint) { - ToProto(endpoint.ticket, pb_endpoint->mutable_ticket()); - pb_endpoint->clear_location(); - for (const Location& location : endpoint.locations) { - ToProto(location, pb_endpoint->add_location()); - } -} - -// FlightDescriptor - -Status FromProto(const pb::FlightDescriptor& pb_descriptor, - FlightDescriptor* descriptor) { - if (pb_descriptor.type() == pb::FlightDescriptor::PATH) { - descriptor->type = FlightDescriptor::PATH; - descriptor->path.reserve(pb_descriptor.path_size()); - for (int i = 0; i < pb_descriptor.path_size(); ++i) { - descriptor->path.emplace_back(pb_descriptor.path(i)); - } - } else if (pb_descriptor.type() == pb::FlightDescriptor::CMD) { - descriptor->type = FlightDescriptor::CMD; - descriptor->cmd = pb_descriptor.cmd(); - } else { - return Status::Invalid("Client sent UNKNOWN descriptor type"); - } - return Status::OK(); -} - -Status ToProto(const FlightDescriptor& descriptor, pb::FlightDescriptor* pb_descriptor) { - if (descriptor.type == FlightDescriptor::PATH) { - pb_descriptor->set_type(pb::FlightDescriptor::PATH); - for (const std::string& path : descriptor.path) { - pb_descriptor->add_path(path); - } - } else { - pb_descriptor->set_type(pb::FlightDescriptor::CMD); - pb_descriptor->set_cmd(descriptor.cmd); - } - return Status::OK(); -} - -// FlightInfo - -Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) { - RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor)); - - info->schema = pb_info.schema(); - - 
info->endpoints.resize(pb_info.endpoint_size()); - for (int i = 0; i < pb_info.endpoint_size(); ++i) { - RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i])); - } - - info->total_records = pb_info.total_records(); - info->total_bytes = pb_info.total_bytes(); - return Status::OK(); -} - -Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) { - basic_auth->password = pb_basic_auth.password(); - basic_auth->username = pb_basic_auth.username(); - - return Status::OK(); -} - -Status FromProto(const pb::SchemaResult& pb_result, std::string* result) { - *result = pb_result.schema(); - return Status::OK(); -} - -Status SchemaToString(const Schema& schema, std::string* out) { - ipc::DictionaryMemo unused_dict_memo; - ARROW_ASSIGN_OR_RAISE(std::shared_ptr serialized_schema, - ipc::SerializeSchema(schema)); - *out = std::string(reinterpret_cast(serialized_schema->data()), - static_cast(serialized_schema->size())); - return Status::OK(); -} - -Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) { - // clear any repeated fields - pb_info->clear_endpoint(); - - pb_info->set_schema(info.serialized_schema()); - - // descriptor - RETURN_NOT_OK(ToProto(info.descriptor(), pb_info->mutable_flight_descriptor())); - - // endpoints - for (const FlightEndpoint& endpoint : info.endpoints()) { - ToProto(endpoint, pb_info->add_endpoint()); - } - - pb_info->set_total_records(info.total_records()); - pb_info->set_total_bytes(info.total_bytes()); - return Status::OK(); -} - -Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result) { - pb_result->set_schema(result.serialized_schema()); - return Status::OK(); -} - -Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out) { - // TODO(lidavidm): make these use Result - std::string str_descr; - pb::FlightDescriptor pb_descr; - RETURN_NOT_OK(ToProto(descr, &pb_descr)); - if (!pb_descr.SerializeToString(&str_descr)) { - return Status::UnknownError("Failed to serialize Flight 
descriptor"); - } - *out = Buffer::FromString(std::move(str_descr)); - return Status::OK(); -} - -} // namespace internal -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h deleted file mode 100644 index c0964c68fdf..00000000000 --- a/cpp/src/arrow/flight/internal.h +++ /dev/null @@ -1,128 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#pragma once - -#include -#include - -#include "arrow/flight/protocol_internal.h" // IWYU pragma: keep -#include "arrow/flight/types.h" -#include "arrow/util/macros.h" - -namespace grpc { - -class Status; - -} // namespace grpc - -namespace arrow { - -class Schema; -class Status; - -namespace pb = arrow::flight::protocol; - -namespace ipc { - -class Message; - -} // namespace ipc - -namespace flight { - -#define GRPC_RETURN_NOT_OK(expr) \ - do { \ - ::arrow::Status _s = (expr); \ - if (ARROW_PREDICT_FALSE(!_s.ok())) { \ - return ::arrow::flight::internal::ToGrpcStatus(_s); \ - } \ - } while (0) - -#define GRPC_RETURN_NOT_GRPC_OK(expr) \ - do { \ - ::grpc::Status _s = (expr); \ - if (ARROW_PREDICT_FALSE(!_s.ok())) { \ - return _s; \ - } \ - } while (0) - -namespace internal { - -/// The name of the header used to pass authentication tokens. -ARROW_FLIGHT_EXPORT -extern const char* kGrpcAuthHeader; - -/// The name of the header used to pass the exact Arrow status code. -ARROW_FLIGHT_EXPORT -extern const char* kGrpcStatusCodeHeader; - -/// The name of the header used to pass the exact Arrow status message. -ARROW_FLIGHT_EXPORT -extern const char* kGrpcStatusMessageHeader; - -/// The name of the header used to pass the exact Arrow status detail. -ARROW_FLIGHT_EXPORT -extern const char* kGrpcStatusDetailHeader; - -ARROW_FLIGHT_EXPORT -extern const char* kBinaryErrorDetailsKey; - -ARROW_FLIGHT_EXPORT -Status SchemaToString(const Schema& schema, std::string* out); - -/// Convert a gRPC status to an Arrow status. Optionally, provide a -/// ClientContext to recover the exact Arrow status if it was passed -/// over the wire. -ARROW_FLIGHT_EXPORT -Status FromGrpcStatus(const grpc::Status& grpc_status, - grpc::ClientContext* ctx = nullptr); - -ARROW_FLIGHT_EXPORT -grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx = nullptr); - -// These functions depend on protobuf types which are not exported in the Flight DLL. 
- -Status FromProto(const pb::ActionType& pb_type, ActionType* type); -Status FromProto(const pb::Action& pb_action, Action* action); -Status FromProto(const pb::Result& pb_result, Result* result); -Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); -Status FromProto(const pb::Location& pb_location, Location* location); -Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); -Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, - std::unique_ptr* message); -Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); -Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); -Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info); -Status FromProto(const pb::SchemaResult& pb_result, std::string* result); -Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info); - -Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr); -Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info); -Status ToProto(const ActionType& type, pb::ActionType* pb_type); -Status ToProto(const Action& action, pb::Action* pb_action); -Status ToProto(const Result& result, pb::Result* pb_result); -Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria); -Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result); -void ToProto(const Ticket& ticket, pb::Ticket* pb_ticket); -Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth); - -Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out); - -} // namespace internal -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index 89ffc046c82..ae89f88dde1 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -36,8 +36,8 @@ #include "arrow/util/logging.h" #include "arrow/flight/api.h" -#include "arrow/flight/internal.h" #include 
"arrow/flight/perf.pb.h" +#include "arrow/flight/protocol_internal.h" #include "arrow/flight/test_util.h" #ifdef ARROW_CUDA diff --git a/cpp/src/arrow/flight/platform.h b/cpp/src/arrow/flight/platform.h index 7f1b0954d84..8f8db2d2dc8 100644 --- a/cpp/src/arrow/flight/platform.h +++ b/cpp/src/arrow/flight/platform.h @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Internal header. Platform-specific definitions for gRPC. +// Internal header. Platform-specific definitions for Flight. #pragma once @@ -28,5 +28,4 @@ #endif // _MSC_VER -#include "arrow/util/config.h" // IWYU pragma: keep -#include "arrow/util/windows_compatibility.h" // IWYU pragma: keep +#include "arrow/util/config.h" // IWYU pragma: keep diff --git a/cpp/src/arrow/flight/protocol_internal.h b/cpp/src/arrow/flight/protocol_internal.h index 98bf9238809..60e87186488 100644 --- a/cpp/src/arrow/flight/protocol_internal.h +++ b/cpp/src/arrow/flight/protocol_internal.h @@ -19,10 +19,5 @@ // This addresses platform-specific defines, e.g. on Windows #include "arrow/flight/platform.h" // IWYU pragma: keep -// This header holds the Flight protobuf definitions. - -// Need to include this first to get our gRPC customizations -#include "arrow/flight/customize_protobuf.h" // IWYU pragma: export - -#include "arrow/flight/Flight.grpc.pb.h" // IWYU pragma: export -#include "arrow/flight/Flight.pb.h" // IWYU pragma: export +// This header holds the Flight Protobuf definitions. 
+#include "arrow/flight/Flight.pb.h" // IWYU pragma: export diff --git a/cpp/src/arrow/flight/serialization_internal.cc b/cpp/src/arrow/flight/serialization_internal.cc index cbea9566390..bbffc643466 100644 --- a/cpp/src/arrow/flight/serialization_internal.cc +++ b/cpp/src/arrow/flight/serialization_internal.cc @@ -17,473 +17,244 @@ #include "arrow/flight/serialization_internal.h" -#include -#include +#include #include -#include - -#include "arrow/flight/platform.h" - -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4267) -#endif - -#include -#include -#include - -#include -#ifdef GRPCPP_PP_INCLUDE -#include -#include -#else -#include -#include -#endif - -#if defined(_MSC_VER) -#pragma warning(pop) -#endif #include "arrow/buffer.h" -#include "arrow/device.h" -#include "arrow/flight/internal.h" -#include "arrow/flight/server.h" -#include "arrow/ipc/message.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" -#include "arrow/util/bit_util.h" -#include "arrow/util/logging.h" - -static constexpr int64_t kInt32Max = std::numeric_limits::max(); +#include "arrow/result.h" +#include "arrow/status.h" namespace arrow { namespace flight { namespace internal { -namespace pb = arrow::flight::protocol; +// ActionType -using arrow::ipc::IpcPayload; +Status FromProto(const pb::ActionType& pb_type, ActionType* type) { + type->type = pb_type.type(); + type->description = pb_type.description(); + return Status::OK(); +} -using google::protobuf::internal::WireFormatLite; -using google::protobuf::io::ArrayOutputStream; -using google::protobuf::io::CodedInputStream; -using google::protobuf::io::CodedOutputStream; +Status ToProto(const ActionType& type, pb::ActionType* pb_type) { + pb_type->set_type(type.type); + pb_type->set_description(type.description); + return Status::OK(); +} -using grpc::ByteBuffer; +// Action -bool ReadBytesZeroCopy(const std::shared_ptr& source_data, - CodedInputStream* input, std::shared_ptr* 
out) { - uint32_t length; - if (!input->ReadVarint32(&length)) { - return false; - } - auto buf = - SliceBuffer(source_data, input->CurrentPosition(), static_cast(length)); - *out = buf; - return input->Skip(static_cast(length)); +Status FromProto(const pb::Action& pb_action, Action* action) { + action->type = pb_action.type(); + action->body = Buffer::FromString(pb_action.body()); + return Status::OK(); } -// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow -// consumers with zero-copy -class GrpcBuffer : public MutableBuffer { - public: - GrpcBuffer(grpc_slice slice, bool incref) - : MutableBuffer(GRPC_SLICE_START_PTR(slice), - static_cast(GRPC_SLICE_LENGTH(slice))), - slice_(incref ? grpc_slice_ref(slice) : slice) {} - - ~GrpcBuffer() override { - // Decref slice - grpc_slice_unref(slice_); +Status ToProto(const Action& action, pb::Action* pb_action) { + pb_action->set_type(action.type); + if (action.body) { + pb_action->set_body(action.body->ToString()); } + return Status::OK(); +} - static Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { - // These types are guaranteed by static assertions in gRPC to have the same - // in-memory representation - - auto buffer = *reinterpret_cast(cpp_buf); - - // This part below is based on the Flatbuffers gRPC SerializationTraits in - // flatbuffers/grpc.h - - // Check if this is a single uncompressed slice. - if ((buffer->type == GRPC_BB_RAW) && - (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && - (buffer->data.raw.slice_buffer.count == 1)) { - // If it is, then we can reference the `grpc_slice` directly. - grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; - - if (slice.refcount) { - // Increment reference count so this memory remains valid - *out = std::make_shared(slice, true); - } else { - // Small slices (less than GRPC_SLICE_INLINED_SIZE bytes) are - // inlined into the structure and must be copied. 
- const uint8_t length = slice.data.inlined.length; - ARROW_ASSIGN_OR_RAISE(*out, arrow::AllocateBuffer(length)); - std::memcpy((*out)->mutable_data(), slice.data.inlined.bytes, length); - } - } else { - // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read - // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives - // us back a new slice with the refcount already incremented. - grpc_byte_buffer_reader reader; - if (!grpc_byte_buffer_reader_init(&reader, buffer)) { - return Status::IOError("Internal gRPC error reading from ByteBuffer"); - } - grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); - grpc_byte_buffer_reader_destroy(&reader); - - // Steal the slice reference - *out = std::make_shared(slice, false); - } - - return Status::OK(); - } +// Result (of an Action) - private: - grpc_slice slice_; -}; +Status FromProto(const pb::Result& pb_result, Result* result) { + // ARROW-3250; can avoid copy. Can also write custom deserializer if it + // becomes an issue + result->body = Buffer::FromString(pb_result.body()); + return Status::OK(); +} -// Destructor callback for grpc::Slice -static void ReleaseBuffer(void* buf_ptr) { - delete reinterpret_cast*>(buf_ptr); +Status ToProto(const Result& result, pb::Result* pb_result) { + pb_result->set_body(result.body->ToString()); + return Status::OK(); } -// Initialize gRPC Slice from arrow Buffer -arrow::Result SliceFromBuffer(const std::shared_ptr& buf) { - // Allocate persistent shared_ptr to control Buffer lifetime - std::shared_ptr* ptr = nullptr; - if (ARROW_PREDICT_TRUE(buf->is_cpu())) { - ptr = new std::shared_ptr(buf); - } else { - // Non-CPU buffer, must copy to CPU-accessible buffer first - ARROW_ASSIGN_OR_RAISE(auto cpu_buf, - Buffer::ViewOrCopy(buf, default_cpu_memory_manager())); - ptr = new std::shared_ptr(cpu_buf); - } - grpc::Slice slice(const_cast((*ptr)->data()), - static_cast((*ptr)->size()), &ReleaseBuffer, ptr); - // Make sure no copy was done (some grpc::Slice() 
constructors do an implicit memcpy) - DCHECK_EQ(slice.begin(), (*ptr)->data()); - return slice; +// Criteria + +Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria) { + criteria->expression = pb_criteria.expression(); + return Status::OK(); +} +Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria) { + pb_criteria->set_expression(criteria.expression); + return Status::OK(); } -static const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; +// Location -// Update the sizes of our Protobuf fields based on the given IPC payload. -grpc::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body, - size_t* header_size, int32_t* metadata_size) { - DCHECK_LE(ipc_msg.metadata->size(), kInt32Max); - *metadata_size = static_cast(ipc_msg.metadata->size()); +Status FromProto(const pb::Location& pb_location, Location* location) { + return Location::Parse(pb_location.uri(), location); +} - // 1 byte for metadata tag - *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size); +Status ToProto(const Location& location, pb::Location* pb_location) { + pb_location->set_uri(location.ToString()); + return Status::OK(); +} - // 2 bytes for body tag - if (has_body) { - // We write the body tag in the header but not the actual body data - *header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) - - ipc_msg.body_length; - } +Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth) { + pb_basic_auth->set_username(basic_auth.username); + pb_basic_auth->set_password(basic_auth.password); + return Status::OK(); +} - return grpc::Status::OK; +// Ticket + +Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket) { + ticket->ticket = pb_ticket.ticket(); + return Status::OK(); } -grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, - bool* own_buffer) { - // Size of the IPC body (protobuf: data_body) - size_t body_size = 0; - // Size of the Protobuf "header" 
(everything except for the body) - size_t header_size = 0; - // Size of IPC header metadata (protobuf: data_header) - int32_t metadata_size = 0; - - // Write the descriptor if present - int32_t descriptor_size = 0; - if (msg.descriptor != nullptr) { - DCHECK_LE(msg.descriptor->size(), kInt32Max); - descriptor_size = static_cast(msg.descriptor->size()); - header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); - } +Status ToProto(const Ticket& ticket, pb::Ticket* pb_ticket) { + pb_ticket->set_ticket(ticket.ticket); + return Status::OK(); +} - // App metadata tag if appropriate - int32_t app_metadata_size = 0; - if (msg.app_metadata && msg.app_metadata->size() > 0) { - DCHECK_LE(msg.app_metadata->size(), kInt32Max); - app_metadata_size = static_cast(msg.app_metadata->size()); - header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size); +// FlightData + +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, + std::unique_ptr* message) { + RETURN_NOT_OK(internal::FromProto(pb_data.flight_descriptor(), descriptor)); + const std::string& header = pb_data.data_header(); + const std::string& body = pb_data.data_body(); + std::shared_ptr header_buf = Buffer::Wrap(header.data(), header.size()); + std::shared_ptr body_buf = Buffer::Wrap(body.data(), body.size()); + if (header_buf == nullptr || body_buf == nullptr) { + return Status::UnknownError("Could not create buffers from protobuf"); } + return ipc::Message::Open(header_buf, body_buf).Value(message); +} - const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message; - // No data in this payload (metadata-only). - bool has_ipc = ipc_msg.type != ipc::MessageType::NONE; - bool has_body = has_ipc ? 
ipc::Message::HasBody(ipc_msg.type) : false; +// FlightEndpoint - if (has_ipc) { - DCHECK(has_body || ipc_msg.body_length == 0); - GRPC_RETURN_NOT_GRPC_OK( - IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size)); - body_size = static_cast(ipc_msg.body_length); +Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint) { + RETURN_NOT_OK(FromProto(pb_endpoint.ticket(), &endpoint->ticket)); + endpoint->locations.resize(pb_endpoint.location_size()); + for (int i = 0; i < pb_endpoint.location_size(); ++i) { + RETURN_NOT_OK(FromProto(pb_endpoint.location(i), &endpoint->locations[i])); } + return Status::OK(); +} - // TODO(wesm): messages over 2GB unlikely to be yet supported - // Validated in WritePayload since returning error here causes gRPC to fail an assertion - DCHECK_LE(body_size, kInt32Max); - - // Allocate and initialize slices - std::vector slices; - slices.emplace_back(header_size); - - // Force the header_stream to be destructed, which actually flushes - // the data into the slice. 
- { - ArrayOutputStream header_writer(const_cast(slices[0].begin()), - static_cast(slices[0].size())); - CodedOutputStream header_stream(&header_writer); - - // Write descriptor - if (msg.descriptor != nullptr) { - WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(descriptor_size); - header_stream.WriteRawMaybeAliased(msg.descriptor->data(), - static_cast(msg.descriptor->size())); - } +Status ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint) { + RETURN_NOT_OK(ToProto(endpoint.ticket, pb_endpoint->mutable_ticket())); + pb_endpoint->clear_location(); + for (const Location& location : endpoint.locations) { + RETURN_NOT_OK(ToProto(location, pb_endpoint->add_location())); + } + return Status::OK(); +} - // Write header - if (has_ipc) { - WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(metadata_size); - header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), - static_cast(ipc_msg.metadata->size())); - } +// FlightDescriptor - // Write app metadata - if (app_metadata_size > 0) { - WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(app_metadata_size); - header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), - static_cast(msg.app_metadata->size())); +Status FromProto(const pb::FlightDescriptor& pb_descriptor, + FlightDescriptor* descriptor) { + if (pb_descriptor.type() == pb::FlightDescriptor::PATH) { + descriptor->type = FlightDescriptor::PATH; + descriptor->path.reserve(pb_descriptor.path_size()); + for (int i = 0; i < pb_descriptor.path_size(); ++i) { + descriptor->path.emplace_back(pb_descriptor.path(i)); } + } else if (pb_descriptor.type() == pb::FlightDescriptor::CMD) { + descriptor->type = 
FlightDescriptor::CMD; + descriptor->cmd = pb_descriptor.cmd(); + } else { + return Status::Invalid("Client sent UNKNOWN descriptor type"); + } + return Status::OK(); +} - if (has_body) { - // Write body tag - WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, - WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); - header_stream.WriteVarint32(static_cast(body_size)); - - // Enqueue body buffers for writing, without copying - for (const auto& buffer : ipc_msg.body_buffers) { - // Buffer may be null when the row length is zero, or when all - // entries are invalid. - if (!buffer) continue; - - grpc::Slice slice; - auto status = SliceFromBuffer(buffer).Value(&slice); - if (ARROW_PREDICT_FALSE(!status.ok())) { - // This will likely lead to abort as gRPC cannot recover from an error here - return ToGrpcStatus(status); - } - slices.push_back(std::move(slice)); - - // Write padding if not multiple of 8 - const auto remainder = static_cast( - bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); - if (remainder) { - slices.push_back(grpc::Slice(kPaddingBytes, remainder)); - } - } +Status ToProto(const FlightDescriptor& descriptor, pb::FlightDescriptor* pb_descriptor) { + if (descriptor.type == FlightDescriptor::PATH) { + pb_descriptor->set_type(pb::FlightDescriptor::PATH); + for (const std::string& path : descriptor.path) { + pb_descriptor->add_path(path); } - - DCHECK_EQ(static_cast(header_size), header_stream.ByteCount()); + } else { + pb_descriptor->set_type(pb::FlightDescriptor::CMD); + pb_descriptor->set_cmd(descriptor.cmd); } - - // Hand off the slices to the returned ByteBuffer - *out = grpc::ByteBuffer(slices.data(), slices.size()); - *own_buffer = true; - return grpc::Status::OK; + return Status::OK(); } -// Read internal::FlightData from grpc::ByteBuffer containing FlightData -// protobuf without copying -grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { - if (!buffer) { - return 
grpc::Status(grpc::StatusCode::INTERNAL, "No payload"); - } +// FlightInfo - // Reset fields in case the caller reuses a single allocation - out->descriptor = nullptr; - out->app_metadata = nullptr; - out->metadata = nullptr; - out->body = nullptr; - - std::shared_ptr wrapped_buffer; - GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); - - auto buffer_length = static_cast(wrapped_buffer->size()); - CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); - - pb_stream.SetTotalBytesLimit(buffer_length); - - // This is the bytes remaining when using CodedInputStream like this - while (pb_stream.BytesUntilTotalBytesLimit()) { - const uint32_t tag = pb_stream.ReadTag(); - const int field_number = WireFormatLite::GetTagFieldNumber(tag); - switch (field_number) { - case pb::FlightData::kFlightDescriptorFieldNumber: { - pb::FlightDescriptor pb_descriptor; - uint32_t length; - if (!pb_stream.ReadVarint32(&length)) { - return grpc::Status(grpc::StatusCode::INTERNAL, - "Unable to parse length of FlightDescriptor"); - } - // Can't use ParseFromCodedStream as this reads the entire - // rest of the stream into the descriptor command field. 
- std::string buffer; - pb_stream.ReadString(&buffer, length); - if (!pb_descriptor.ParseFromString(buffer)) { - return grpc::Status(grpc::StatusCode::INTERNAL, - "Unable to parse FlightDescriptor"); - } - arrow::flight::FlightDescriptor descriptor; - GRPC_RETURN_NOT_OK( - arrow::flight::internal::FromProto(pb_descriptor, &descriptor)); - out->descriptor.reset(new arrow::flight::FlightDescriptor(descriptor)); - } break; - case pb::FlightData::kDataHeaderFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { - return grpc::Status(grpc::StatusCode::INTERNAL, - "Unable to read FlightData metadata"); - } - } break; - case pb::FlightData::kAppMetadataFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) { - return grpc::Status(grpc::StatusCode::INTERNAL, - "Unable to read FlightData application metadata"); - } - } break; - case pb::FlightData::kDataBodyFieldNumber: { - if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { - return grpc::Status(grpc::StatusCode::INTERNAL, - "Unable to read FlightData body"); - } - } break; - default: - DCHECK(false) << "cannot happen"; - } - } - buffer->Clear(); +Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) { + RETURN_NOT_OK(FromProto(pb_info.flight_descriptor(), &info->descriptor)); - // TODO(wesm): Where and when should we verify that the FlightData is not - // malformed? + info->schema = pb_info.schema(); - // Set the default value for an unspecified FlightData body. The other - // fields can be null if they're unspecified. 
- if (out->body == nullptr) { - out->body = std::make_shared(nullptr, 0); + info->endpoints.resize(pb_info.endpoint_size()); + for (int i = 0; i < pb_info.endpoint_size(); ++i) { + RETURN_NOT_OK(FromProto(pb_info.endpoint(i), &info->endpoints[i])); } - return grpc::Status::OK; -} - -::arrow::Result> FlightData::OpenMessage() { - return ipc::Message::Open(metadata, body); + info->total_records = pb_info.total_records(); + info->total_bytes = pb_info.total_bytes(); + return Status::OK(); } -// The pointer bitcast hack below causes legitimate warnings, silence them. -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif +Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) { + basic_auth->password = pb_basic_auth.password(); + basic_auth->username = pb_basic_auth.username(); -// Pointer bitcast explanation: grpc::*Writer::Write() and grpc::*Reader::Read() -// both take a T* argument (here pb::FlightData*). But they don't do anything -// with that argument except pass it to SerializationTraits::Serialize() and -// SerializationTraits::Deserialize(). -// -// Since we control SerializationTraits, we can interpret the -// pointer argument whichever way we want, including cast it back to the original type. -// (see customize_protobuf.h). 
- -Status WritePayload(const FlightPayload& payload, - grpc::ClientReaderWriter* writer) { - RETURN_NOT_OK(payload.Validate()); - // Pretend to be pb::FlightData and intercept in SerializationTraits - if (!writer->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { - return Status::IOError("Could not write payload to stream"); - } return Status::OK(); } -Status WritePayload(const FlightPayload& payload, - grpc::ClientReaderWriter* writer) { - RETURN_NOT_OK(payload.Validate()); - // Pretend to be pb::FlightData and intercept in SerializationTraits - if (!writer->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { - return Status::IOError("Could not write payload to stream"); - } +Status FromProto(const pb::SchemaResult& pb_result, std::string* result) { + *result = pb_result.schema(); return Status::OK(); } -Status WritePayload(const FlightPayload& payload, - grpc::ServerReaderWriter* writer) { - RETURN_NOT_OK(payload.Validate()); - // Pretend to be pb::FlightData and intercept in SerializationTraits - if (!writer->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { - return Status::IOError("Could not write payload to stream"); - } +Status SchemaToString(const Schema& schema, std::string* out) { + ipc::DictionaryMemo unused_dict_memo; + ARROW_ASSIGN_OR_RAISE(std::shared_ptr serialized_schema, + ipc::SerializeSchema(schema)); + *out = std::string(reinterpret_cast(serialized_schema->data()), + static_cast(serialized_schema->size())); return Status::OK(); } -Status WritePayload(const FlightPayload& payload, - grpc::ServerWriter* writer) { - RETURN_NOT_OK(payload.Validate()); - // Pretend to be pb::FlightData and intercept in SerializationTraits - if (!writer->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { - return Status::IOError("Could not write payload to stream"); - } - return Status::OK(); -} +Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) { + // clear any repeated fields + pb_info->clear_endpoint(); 
-bool ReadPayload(grpc::ClientReader* reader, FlightData* data) { - // Pretend to be pb::FlightData and intercept in SerializationTraits - return reader->Read(reinterpret_cast(data)); -} + pb_info->set_schema(info.serialized_schema()); -bool ReadPayload(grpc::ClientReaderWriter* reader, - FlightData* data) { - // Pretend to be pb::FlightData and intercept in SerializationTraits - return reader->Read(reinterpret_cast(data)); -} + // descriptor + RETURN_NOT_OK(ToProto(info.descriptor(), pb_info->mutable_flight_descriptor())); -bool ReadPayload(grpc::ServerReaderWriter* reader, - FlightData* data) { - // Pretend to be pb::FlightData and intercept in SerializationTraits - return reader->Read(reinterpret_cast(data)); -} + // endpoints + for (const FlightEndpoint& endpoint : info.endpoints()) { + RETURN_NOT_OK(ToProto(endpoint, pb_info->add_endpoint())); + } -bool ReadPayload(grpc::ServerReaderWriter* reader, - FlightData* data) { - // Pretend to be pb::FlightData and intercept in SerializationTraits - return reader->Read(reinterpret_cast(data)); + pb_info->set_total_records(info.total_records()); + pb_info->set_total_bytes(info.total_bytes()); + return Status::OK(); } -bool ReadPayload(grpc::ClientReaderWriter* reader, - pb::PutResult* data) { - return reader->Read(data); +Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result) { + pb_result->set_schema(result.serialized_schema()); + return Status::OK(); } -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif +Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out) { + // TODO(ARROW-15612): make these use Result + std::string str_descr; + pb::FlightDescriptor pb_descr; + RETURN_NOT_OK(ToProto(descr, &pb_descr)); + if (!pb_descr.SerializeToString(&str_descr)) { + return Status::UnknownError("Failed to serialize Flight descriptor"); + } + *out = Buffer::FromString(std::move(str_descr)); + return Status::OK(); +} } // namespace internal } // namespace flight diff --git 
a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index 5f7d0cc487c..c27bc79b315 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -15,82 +15,73 @@ // specific language governing permissions and limitations // under the License. -// (De)serialization utilities that hook into gRPC, efficiently -// handling Arrow-encoded data in a gRPC call. +// Generic Flight I/O utilities. #pragma once -#include - -#include "arrow/flight/internal.h" +#include "arrow/flight/protocol_internal.h" // IWYU pragma: keep +#include "arrow/flight/transport.h" #include "arrow/flight/types.h" -#include "arrow/ipc/message.h" -#include "arrow/result.h" +#include "arrow/util/macros.h" namespace arrow { -class Buffer; +class Schema; +class Status; + +namespace ipc { +class Message; +} // namespace ipc namespace flight { +namespace pb = arrow::flight::protocol; namespace internal { -/// Internal, not user-visible type used for memory-efficient reads from gRPC -/// stream -struct FlightData { - /// Used only for puts, may be null - std::unique_ptr descriptor; - - /// Non-length-prefixed Message header as described in format/Message.fbs - std::shared_ptr metadata; - - /// Application-defined metadata - std::shared_ptr app_metadata; - - /// Message body - std::shared_ptr body; - - /// Open IPC message from the metadata and body - ::arrow::Result> OpenMessage(); -}; - -/// Write Flight message on gRPC stream with zero-copy optimizations. 
-// Returns Invalid if the payload is ill-formed -// Returns IOError if gRPC did not write the message (note this is not -// necessarily an error - the client may simply have gone away) -Status WritePayload(const FlightPayload& payload, - grpc::ClientReaderWriter* writer); -Status WritePayload(const FlightPayload& payload, - grpc::ClientReaderWriter* writer); -Status WritePayload(const FlightPayload& payload, - grpc::ServerReaderWriter* writer); -Status WritePayload(const FlightPayload& payload, - grpc::ServerWriter* writer); - -/// Read Flight message from gRPC stream with zero-copy optimizations. -/// True is returned on success, false if stream ended. -bool ReadPayload(grpc::ClientReader* reader, FlightData* data); -bool ReadPayload(grpc::ClientReaderWriter* reader, - FlightData* data); -bool ReadPayload(grpc::ServerReaderWriter* reader, - FlightData* data); -bool ReadPayload(grpc::ServerReaderWriter* reader, - FlightData* data); -// Overload to make genericity easier in DoPutPayloadWriter -bool ReadPayload(grpc::ClientReaderWriter* reader, - pb::PutResult* data); +/// \brief The header used for transmitting authentication/authorization data. +static constexpr char kAuthHeader[] = "authorization"; + +ARROW_FLIGHT_EXPORT +Status SchemaToString(const Schema& schema, std::string* out); + +// These functions depend on protobuf types which are not exported in the Flight DLL. 
+ +Status FromProto(const pb::ActionType& pb_type, ActionType* type); +Status FromProto(const pb::Action& pb_action, Action* action); +Status FromProto(const pb::Result& pb_result, Result* result); +Status FromProto(const pb::Criteria& pb_criteria, Criteria* criteria); +Status FromProto(const pb::Location& pb_location, Location* location); +Status FromProto(const pb::Ticket& pb_ticket, Ticket* ticket); +Status FromProto(const pb::FlightData& pb_data, FlightDescriptor* descriptor, + std::unique_ptr* message); +Status FromProto(const pb::FlightDescriptor& pb_descr, FlightDescriptor* descr); +Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint); +Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info); +Status FromProto(const pb::SchemaResult& pb_result, std::string* result); +Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info); + +Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr); +Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info); +Status ToProto(const ActionType& type, pb::ActionType* pb_type); +Status ToProto(const Action& action, pb::Action* pb_action); +Status ToProto(const Result& result, pb::Result* pb_result); +Status ToProto(const Criteria& criteria, pb::Criteria* pb_criteria); +Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result); +Status ToProto(const Ticket& ticket, pb::Ticket* pb_ticket); +Status ToProto(const BasicAuth& basic_auth, pb::BasicAuth* pb_basic_auth); + +Status ToPayload(const FlightDescriptor& descr, std::shared_ptr* out); // We want to reuse RecordBatchStreamReader's implementation while // (1) Adapting it to the Flight message format // (2) Allowing pure-metadata messages before data is sent // (3) Reusing the reader implementation between DoGet and DoExchange. -// To do this, we wrap the gRPC reader in a peekable iterator. 
-// The Flight reader can then peek at the message to determine whether -// it has application metadata or not, and pass the message to -// RecordBatchStreamReader as appropriate. -template +// To do this, we wrap the transport-level reader in a peekable +// iterator. The Flight reader can then peek at the message to +// determine whether it has application metadata or not, and pass the +// message to RecordBatchStreamReader as appropriate. class PeekableFlightDataReader { public: - explicit PeekableFlightDataReader(ReaderPtr stream) + explicit PeekableFlightDataReader(TransportDataStream* stream) : stream_(stream), peek_(), finished_(false), valid_(false) {} void Peek(internal::FlightData** out) { @@ -132,7 +123,7 @@ class PeekableFlightDataReader { return valid_; } - if (!internal::ReadPayload(&*stream_, &peek_)) { + if (!stream_->ReadData(&peek_)) { finished_ = true; valid_ = false; } else { @@ -141,7 +132,7 @@ class PeekableFlightDataReader { return valid_; } - ReaderPtr stream_; + internal::TransportDataStream* stream_; internal::FlightData peek_; bool finished_; bool valid_; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 1c613b3c7c0..dff8d075610 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. 
+// Flight server lifecycle implementation on top of the transport +// interface + // Platform-specific defines #include "arrow/flight/platform.h" #include "arrow/flight/server.h" #ifdef _WIN32 +#include "arrow/util/windows_compatibility.h" + #include #else #include @@ -31,758 +36,23 @@ #include #include #include -#include -#include #include -#include #include -#ifdef GRPCPP_PP_INCLUDE -#include -#else -#include -#endif - -#include "arrow/buffer.h" -#include "arrow/ipc/dictionary.h" -#include "arrow/ipc/options.h" -#include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" -#include "arrow/memory_pool.h" -#include "arrow/record_batch.h" +#include "arrow/device.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/grpc/grpc_server.h" +#include "arrow/flight/transport_server.h" +#include "arrow/flight/types.h" #include "arrow/status.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" +#include "arrow/util/string_view.h" #include "arrow/util/uri.h" -#include "arrow/flight/internal.h" -#include "arrow/flight/middleware.h" -#include "arrow/flight/middleware_internal.h" -#include "arrow/flight/serialization_internal.h" -#include "arrow/flight/server_auth.h" -#include "arrow/flight/server_middleware.h" -#include "arrow/flight/types.h" - -using FlightService = arrow::flight::protocol::FlightService; -using ServerContext = grpc::ServerContext; - -template -using ServerWriter = grpc::ServerWriter; - namespace arrow { namespace flight { - -namespace pb = arrow::flight::protocol; - -// Macro that runs interceptors before returning the given status -#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS) \ - do { \ - const auto& __s = (STATUS); \ - return CONTEXT.FinishRequest(__s); \ - } while (false) - -#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \ - if (VAL == nullptr) { \ - RETURN_WITH_MIDDLEWARE(CONTEXT, \ - grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE)); \ - } - -// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC 
status, and -// will run interceptors -#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \ - do { \ - const auto& _s = (expr); \ - if (ARROW_PREDICT_FALSE(!_s.ok())) { \ - return CONTEXT.FinishRequest(_s); \ - } \ - } while (false) - namespace { - -// A MessageReader implementation that reads from a gRPC ServerReader. -// Templated to be generic over DoPut/DoExchange. -template -class FlightIpcMessageReader : public ipc::MessageReader { - public: - explicit FlightIpcMessageReader( - std::shared_ptr> peekable_reader, - std::shared_ptr memory_manager, - std::shared_ptr* app_metadata) - : peekable_reader_(peekable_reader), - memory_manager_(std::move(memory_manager)), - app_metadata_(app_metadata) {} - - ::arrow::Result> ReadNextMessage() override { - if (stream_finished_) { - return nullptr; - } - internal::FlightData* data; - peekable_reader_->Next(&data); - if (!data) { - stream_finished_ = true; - if (first_message_) { - return Status::Invalid( - "Client provided malformed message or did not provide message"); - } - return nullptr; - } - if (ARROW_PREDICT_FALSE(!memory_manager_->is_cpu())) { - ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_)); - } - *app_metadata_ = std::move(data->app_metadata); - return data->OpenMessage(); - } - - protected: - std::shared_ptr> peekable_reader_; - std::shared_ptr memory_manager_; - // A reference to FlightMessageReaderImpl.app_metadata_. That class - // can't access the app metadata because when it Peek()s the stream, - // it may be looking at a dictionary batch, not the record - // batch. Updating it here ensures the reader is always updated with - // the last metadata message read. 
- std::shared_ptr* app_metadata_; - bool first_message_ = true; - bool stream_finished_ = false; -}; - -template -class FlightMessageReaderImpl : public FlightMessageReader { - public: - using GrpcStream = grpc::ServerReaderWriter; - - explicit FlightMessageReaderImpl(GrpcStream* reader, - std::shared_ptr memory_manager) - : reader_(reader), - memory_manager_(std::move(memory_manager)), - peekable_reader_(new internal::PeekableFlightDataReader(reader)) {} - - Status Init() { - // Peek the first message to get the descriptor. - internal::FlightData* data; - peekable_reader_->Peek(&data); - if (!data) { - return Status::IOError("Stream finished before first message sent"); - } - if (!data->descriptor) { - return Status::IOError("Descriptor missing on first message"); - } - descriptor_ = *data->descriptor.get(); // Copy - // If there's a schema (=DoPut), also Open(). - if (data->metadata) { - return EnsureDataStarted(); - } - peekable_reader_->Next(&data); - return Status::OK(); - } - - const FlightDescriptor& descriptor() const override { return descriptor_; } - - arrow::Result> GetSchema() override { - RETURN_NOT_OK(EnsureDataStarted()); - return batch_reader_->schema(); - } - - Status Next(FlightStreamChunk* out) override { - internal::FlightData* data; - peekable_reader_->Peek(&data); - if (!data) { - out->app_metadata = nullptr; - out->data = nullptr; - return Status::OK(); - } - - if (!data->metadata) { - // Metadata-only (data->metadata is the IPC header) - out->app_metadata = data->app_metadata; - out->data = nullptr; - peekable_reader_->Next(&data); - return Status::OK(); - } - - if (!batch_reader_) { - RETURN_NOT_OK(EnsureDataStarted()); - // re-peek here since EnsureDataStarted() advances the stream - return Next(out); - } - RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); - out->app_metadata = std::move(app_metadata_); - return Status::OK(); - } - - private: - /// Ensure we are set up to read data. 
- Status EnsureDataStarted() { - if (!batch_reader_) { - // peek() until we find the first data message; discard metadata - if (!peekable_reader_->SkipToData()) { - return Status::IOError("Client never sent a data message"); - } - auto message_reader = - std::unique_ptr(new FlightIpcMessageReader( - peekable_reader_, memory_manager_, &app_metadata_)); - ARROW_ASSIGN_OR_RAISE( - batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader))); - } - return Status::OK(); - } - - FlightDescriptor descriptor_; - GrpcStream* reader_; - std::shared_ptr memory_manager_; - std::shared_ptr> peekable_reader_; - std::shared_ptr batch_reader_; - std::shared_ptr app_metadata_; -}; - -class GrpcMetadataWriter : public FlightMetadataWriter { - public: - explicit GrpcMetadataWriter( - grpc::ServerReaderWriter* writer) - : writer_(writer) {} - - Status WriteMetadata(const Buffer& buffer) override { - pb::PutResult message{}; - message.set_app_metadata(buffer.data(), buffer.size()); - if (writer_->Write(message)) { - return Status::OK(); - } - return Status::IOError("Unknown error writing metadata."); - } - - private: - grpc::ServerReaderWriter* writer_; -}; - -class GrpcServerAuthReader : public ServerAuthReader { - public: - explicit GrpcServerAuthReader( - grpc::ServerReaderWriter* stream) - : stream_(stream) {} - - Status Read(std::string* token) override { - pb::HandshakeRequest request; - if (stream_->Read(&request)) { - *token = std::move(*request.mutable_payload()); - return Status::OK(); - } - return Status::IOError("Stream is closed."); - } - - private: - grpc::ServerReaderWriter* stream_; -}; - -class GrpcServerAuthSender : public ServerAuthSender { - public: - explicit GrpcServerAuthSender( - grpc::ServerReaderWriter* stream) - : stream_(stream) {} - - Status Write(const std::string& token) override { - pb::HandshakeResponse response; - response.set_payload(token); - if (stream_->Write(response)) { - return Status::OK(); - } - return Status::IOError("Stream 
was closed."); - } - - private: - grpc::ServerReaderWriter* stream_; -}; - -/// The implementation of the write side of a bidirectional FlightData -/// stream for DoExchange. -class DoExchangeMessageWriter : public FlightMessageWriter { - public: - explicit DoExchangeMessageWriter( - grpc::ServerReaderWriter* stream) - : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {} - - Status Begin(const std::shared_ptr& schema, - const ipc::IpcWriteOptions& options) override { - if (started_) { - return Status::Invalid("This writer has already been started."); - } - started_ = true; - ipc_options_ = options; - - RETURN_NOT_OK(mapper_.AddSchemaFields(*schema)); - FlightPayload schema_payload; - RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_, - &schema_payload.ipc_message)); - return WritePayload(schema_payload); - } - - Status WriteRecordBatch(const RecordBatch& batch) override { - return WriteWithMetadata(batch, nullptr); - } - - Status WriteMetadata(std::shared_ptr app_metadata) override { - FlightPayload payload{}; - payload.app_metadata = app_metadata; - return WritePayload(payload); - } - - Status WriteWithMetadata(const RecordBatch& batch, - std::shared_ptr app_metadata) override { - RETURN_NOT_OK(CheckStarted()); - RETURN_NOT_OK(EnsureDictionariesWritten(batch)); - FlightPayload payload{}; - if (app_metadata) { - payload.app_metadata = app_metadata; - } - RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message)); - RETURN_NOT_OK(WritePayload(payload)); - ++stats_.num_record_batches; - return Status::OK(); - } - - Status Close() override { - // It's fine to Close() without writing data - return Status::OK(); - } - - ipc::WriteStats stats() const override { return stats_; } - - private: - Status WritePayload(const FlightPayload& payload) { - RETURN_NOT_OK(internal::WritePayload(payload, stream_)); - ++stats_.num_messages; - return Status::OK(); - } - - Status CheckStarted() { - if (!started_) { - 
return Status::Invalid("This writer is not started. Call Begin() with a schema"); - } - return Status::OK(); - } - - Status EnsureDictionariesWritten(const RecordBatch& batch) { - if (dictionaries_written_) { - return Status::OK(); - } - dictionaries_written_ = true; - ARROW_ASSIGN_OR_RAISE(const auto dictionaries, - ipc::CollectDictionaries(batch, mapper_)); - for (const auto& pair : dictionaries) { - FlightPayload payload{}; - RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_, - &payload.ipc_message)); - RETURN_NOT_OK(WritePayload(payload)); - ++stats_.num_dictionary_batches; - } - return Status::OK(); - } - - grpc::ServerReaderWriter* stream_; - ::arrow::ipc::IpcWriteOptions ipc_options_; - ipc::DictionaryFieldMapper mapper_; - ipc::WriteStats stats_; - bool started_ = false; - bool dictionaries_written_ = false; -}; - -class FlightServiceImpl; -class GrpcServerCallContext : public ServerCallContext { - explicit GrpcServerCallContext(grpc::ServerContext* context) - : context_(context), peer_(context_->peer()) {} - - const std::string& peer_identity() const override { return peer_identity_; } - const std::string& peer() const override { return peer_; } - bool is_cancelled() const override { return context_->IsCancelled(); } - - // Helper method that runs interceptors given the result of an RPC, - // then returns the final gRPC status to send to the client - grpc::Status FinishRequest(const grpc::Status& status) { - // Don't double-convert status - return the original one here - FinishRequest(internal::FromGrpcStatus(status)); - return status; - } - - grpc::Status FinishRequest(const arrow::Status& status) { - for (const auto& instance : middleware_) { - instance->CallCompleted(status); - } - - // Set custom headers to map the exact Arrow status for clients - // who want it. 
- return internal::ToGrpcStatus(status, context_); - } - - ServerMiddleware* GetMiddleware(const std::string& key) const override { - const auto& instance = middleware_map_.find(key); - if (instance == middleware_map_.end()) { - return nullptr; - } - return instance->second.get(); - } - - private: - friend class FlightServiceImpl; - ServerContext* context_; - std::string peer_; - std::string peer_identity_; - std::vector> middleware_; - std::unordered_map> middleware_map_; -}; - -class GrpcAddCallHeaders : public AddCallHeaders { - public: - explicit GrpcAddCallHeaders(grpc::ServerContext* context) : context_(context) {} - ~GrpcAddCallHeaders() override = default; - - void AddHeader(const std::string& key, const std::string& value) override { - context_->AddInitialMetadata(key, value); - } - - private: - grpc::ServerContext* context_; -}; - -// This class glues an implementation of FlightServerBase together with the -// gRPC service definition, so the latter is not exposed in the public API -class FlightServiceImpl : public FlightService::Service { - public: - explicit FlightServiceImpl( - std::shared_ptr auth_handler, - std::shared_ptr memory_manager, - std::vector>> - middleware, - FlightServerBase* server) - : auth_handler_(auth_handler), - memory_manager_(std::move(memory_manager)), - middleware_(middleware), - server_(server) {} - - template - grpc::Status WriteStream(Iterator* iterator, ServerWriter* writer) { - if (!iterator) { - return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate"); - } - // Write flight info to stream until listing is exhausted - while (true) { - ProtoType pb_value; - std::unique_ptr value; - GRPC_RETURN_NOT_OK(iterator->Next(&value)); - if (!value) { - break; - } - GRPC_RETURN_NOT_OK(internal::ToProto(*value, &pb_value)); - - // Blocking write - if (!writer->Write(pb_value)) { - // Write returns false if the stream is closed - break; - } - } - return grpc::Status::OK; - } - - template - grpc::Status WriteStream(const 
std::vector& values, - ServerWriter* writer) { - // Write flight info to stream until listing is exhausted - for (const UserType& value : values) { - ProtoType pb_value; - GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value)); - // Blocking write - if (!writer->Write(pb_value)) { - // Write returns false if the stream is closed - break; - } - } - return grpc::Status::OK; - } - - // Authenticate the client (if applicable) and construct the call context - grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, - GrpcServerCallContext& flight_context) { - if (!auth_handler_) { - const auto auth_context = context->auth_context(); - if (auth_context && auth_context->IsPeerAuthenticated()) { - auto peer_identity = auth_context->GetPeerIdentity(); - flight_context.peer_identity_ = - peer_identity.empty() - ? "" - : std::string(peer_identity.front().begin(), peer_identity.front().end()); - } else { - flight_context.peer_identity_ = ""; - } - } else { - const auto client_metadata = context->client_metadata(); - const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader); - std::string token; - if (auth_header == client_metadata.end()) { - token = ""; - } else { - token = std::string(auth_header->second.data(), auth_header->second.length()); - } - GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_)); - } - - return MakeCallContext(method, context, flight_context); - } - - // Authenticate the client (if applicable) and construct the call context - grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, - GrpcServerCallContext& flight_context) { - // Run server middleware - const CallInfo info{method}; - CallHeaders incoming_headers; - for (const auto& entry : context->client_metadata()) { - incoming_headers.insert( - {util::string_view(entry.first.data(), entry.first.length()), - util::string_view(entry.second.data(), entry.second.length())}); - } - - GrpcAddCallHeaders 
outgoing_headers(context); - for (const auto& factory : middleware_) { - std::shared_ptr instance; - Status result = factory.second->StartCall(info, incoming_headers, &instance); - if (!result.ok()) { - // Interceptor rejected call, end the request on all existing - // interceptors - return flight_context.FinishRequest(result); - } - if (instance != nullptr) { - flight_context.middleware_.push_back(instance); - flight_context.middleware_map_.insert({factory.first, instance}); - instance->SendingHeaders(&outgoing_headers); - } - } - - return grpc::Status::OK; - } - - grpc::Status Handshake( - ServerContext* context, - grpc::ServerReaderWriter* stream) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK( - MakeCallContext(FlightMethod::Handshake, context, flight_context)); - - if (!auth_handler_) { - RETURN_WITH_MIDDLEWARE( - flight_context, - grpc::Status( - grpc::StatusCode::UNIMPLEMENTED, - "This service does not have an authentication mechanism enabled.")); - } - GrpcServerAuthSender outgoing{stream}; - GrpcServerAuthReader incoming{stream}; - RETURN_WITH_MIDDLEWARE(flight_context, - auth_handler_->Authenticate(&outgoing, &incoming)); - } - - grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request, - ServerWriter* writer) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK( - CheckAuth(FlightMethod::ListFlights, context, flight_context)); - - // Retrieve the listing from the implementation - std::unique_ptr listing; - - Criteria criteria; - if (request) { - SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &criteria)); - } - SERVICE_RETURN_NOT_OK(flight_context, - server_->ListFlights(flight_context, &criteria, &listing)); - if (!listing) { - // Treat null listing as no flights available - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); - } - RETURN_WITH_MIDDLEWARE(flight_context, - WriteStream(listing.get(), writer)); - } - - grpc::Status 
GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request, - pb::FlightInfo* response) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK( - CheckAuth(FlightMethod::GetFlightInfo, context, flight_context)); - - CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); - - FlightDescriptor descr; - SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); - - std::unique_ptr info; - SERVICE_RETURN_NOT_OK(flight_context, - server_->GetFlightInfo(flight_context, descr, &info)); - - if (!info) { - // Treat null listing as no flights available - RETURN_WITH_MIDDLEWARE( - flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found")); - } - - SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*info, response)); - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); - } - - grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, - pb::SchemaResult* response) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context)); - - CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); - - FlightDescriptor descr; - SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); - - std::unique_ptr result; - SERVICE_RETURN_NOT_OK(flight_context, - server_->GetSchema(flight_context, descr, &result)); - - if (!result) { - // Treat null listing as no flights available - RETURN_WITH_MIDDLEWARE( - flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found")); - } - - SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, response)); - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); - } - - grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, - ServerWriter* writer) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, 
flight_context)); - - CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null"); - - Ticket ticket; - SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &ticket)); - - std::unique_ptr data_stream; - SERVICE_RETURN_NOT_OK(flight_context, - server_->DoGet(flight_context, ticket, &data_stream)); - - if (!data_stream) { - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, - "No data in this flight")); - } - - // Write the schema as the first message in the stream - FlightPayload schema_payload; - SERVICE_RETURN_NOT_OK(flight_context, data_stream->GetSchemaPayload(&schema_payload)); - auto status = internal::WritePayload(schema_payload, writer); - if (status.IsIOError()) { - // gRPC doesn't give any way for us to know why the message - // could not be written. - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); - } - SERVICE_RETURN_NOT_OK(flight_context, status); - - // Consume data stream and write out payloads - while (true) { - FlightPayload payload; - SERVICE_RETURN_NOT_OK(flight_context, data_stream->Next(&payload)); - // End of stream - if (payload.ipc_message.metadata == nullptr) break; - auto status = internal::WritePayload(payload, writer); - // Connection terminated - if (status.IsIOError()) break; - SERVICE_RETURN_NOT_OK(flight_context, status); - } - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); - } - - grpc::Status DoPut(ServerContext* context, - grpc::ServerReaderWriter* reader) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); - - auto message_reader = std::unique_ptr>( - new FlightMessageReaderImpl(reader, memory_manager_)); - SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); - auto metadata_writer = - std::unique_ptr(new GrpcMetadataWriter(reader)); - RETURN_WITH_MIDDLEWARE(flight_context, - server_->DoPut(flight_context, std::move(message_reader), - std::move(metadata_writer))); 
- } - - grpc::Status DoExchange( - ServerContext* context, - grpc::ServerReaderWriter* stream) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context)); - auto message_reader = std::unique_ptr>( - new FlightMessageReaderImpl(stream, memory_manager_)); - SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); - auto writer = - std::unique_ptr(new DoExchangeMessageWriter(stream)); - RETURN_WITH_MIDDLEWARE(flight_context, - server_->DoExchange(flight_context, std::move(message_reader), - std::move(writer))); - } - - grpc::Status ListActions(ServerContext* context, const pb::Empty* request, - ServerWriter* writer) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK( - CheckAuth(FlightMethod::ListActions, context, flight_context)); - // Retrieve the listing from the implementation - std::vector types; - SERVICE_RETURN_NOT_OK(flight_context, server_->ListActions(flight_context, &types)); - RETURN_WITH_MIDDLEWARE(flight_context, WriteStream(types, writer)); - } - - grpc::Status DoAction(ServerContext* context, const pb::Action* request, - ServerWriter* writer) { - GrpcServerCallContext flight_context(context); - GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context)); - CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null"); - Action action; - SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &action)); - - std::unique_ptr results; - SERVICE_RETURN_NOT_OK(flight_context, - server_->DoAction(flight_context, action, &results)); - - if (!results) { - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::CANCELLED); - } - - while (true) { - std::unique_ptr result; - SERVICE_RETURN_NOT_OK(flight_context, results->Next(&result)); - if (!result) { - // No more results - break; - } - pb::Result pb_result; - SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, &pb_result)); - if 
(!writer->Write(pb_result)) { - // Stream may be closed - break; - } - } - RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); - } - - private: - std::shared_ptr auth_handler_; - std::shared_ptr memory_manager_; - std::vector>> - middleware_; - FlightServerBase* server_; -}; - -} // namespace - -FlightMetadataWriter::~FlightMetadataWriter() = default; - -// -// gRPC server lifecycle -// - #if (ATOMIC_INT_LOCK_FREE != 2 || ATOMIC_POINTER_LOCK_FREE != 2) #error "atomic ints and atomic pointers not always lock-free!" #endif @@ -799,7 +69,7 @@ using ::arrow::internal::SignalHandler; #endif /// RAII guard that manages a self-pipe and a thread that listens on -/// the self-pipe, shutting down the gRPC server when a signal handler +/// the self-pipe, shutting down the server when a signal handler /// writes to the pipe. class ServerSignalHandler { public: @@ -857,18 +127,19 @@ class ServerSignalHandler { arrow::internal::Pipe self_pipe_; std::thread handle_signals_; }; +} // namespace +/// Server implementation. Manages the lifecycle of the "real" server +/// (ServerTransport) and contains struct FlightServerBase::Impl { - std::unique_ptr service_; - std::unique_ptr server_; - int port_; + std::unique_ptr transport_; // Signal handlers (on Windows) and the shutdown handler (other platforms) // are executed in a separate thread, so getting the current thread instance // wouldn't make sense. This means only a single instance can receive signals. static std::atomic running_instance_; // We'll use the self-pipe trick to notify a thread from the signal - // handler. The thread will then shut down the gRPC server. + // handler. The thread will then shut down the server. 
int self_pipe_wfd_; // Signal handling @@ -903,7 +174,10 @@ struct FlightServerBase::Impl { } auto instance = running_instance_.load(); if (instance != nullptr) { - instance->server_->Shutdown(); + auto status = instance->transport_->Shutdown(); + if (!status.ok()) { + ARROW_LOG(WARNING) << "Error shutting down server: " << status.ToString(); + } } } }; @@ -927,65 +201,18 @@ FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } FlightServerBase::~FlightServerBase() {} Status FlightServerBase::Init(const FlightServerOptions& options) { - impl_->service_.reset(new FlightServiceImpl( - options.auth_handler, options.memory_manager, options.middleware, this)); - - grpc::ServerBuilder builder; - // Allow uploading messages of any length - builder.SetMaxReceiveMessageSize(-1); - - const Location& location = options.location; - const std::string scheme = location.scheme(); - if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { - std::stringstream address; - address << arrow::internal::UriEncodeHost(location.uri_->host()) << ':' - << location.uri_->port_text(); - - std::shared_ptr creds; - if (scheme == kSchemeGrpcTls) { - grpc::SslServerCredentialsOptions ssl_options; - for (const auto& pair : options.tls_certificates) { - ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert}); - } - if (options.verify_client) { - ssl_options.client_certificate_request = - GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; - } - if (!options.root_certificates.empty()) { - ssl_options.pem_root_certs = options.root_certificates; - } - creds = grpc::SslServerCredentials(ssl_options); - } else { - creds = grpc::InsecureServerCredentials(); - } - - builder.AddListeningPort(address.str(), creds, &impl_->port_); - } else if (scheme == kSchemeGrpcUnix) { - std::stringstream address; - address << "unix:" << location.uri_->path(); - builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials()); - } else { - return 
Status::NotImplemented("Scheme is not supported: " + scheme); - } - - builder.RegisterService(impl_->service_.get()); - - // Disable SO_REUSEPORT - it makes debugging/testing a pain as - // leftover processes can handle requests on accident - builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); - - if (options.builder_hook) { - options.builder_hook(&builder); - } + flight::transport::grpc::InitializeFlightGrpcServer(); - impl_->server_ = builder.BuildAndStart(); - if (!impl_->server_) { - return Status::UnknownError("Server did not start properly"); - } - return Status::OK(); + const auto scheme = options.location.scheme(); + ARROW_ASSIGN_OR_RAISE(impl_->transport_, + internal::GetDefaultTransportRegistry()->MakeServer( + scheme, this, options.memory_manager)); + return impl_->transport_->Init(options, *options.location.uri_); } -int FlightServerBase::port() const { return impl_->port_; } +int FlightServerBase::port() const { return location().uri_->port(); } + +Location FlightServerBase::location() const { return impl_->transport_->location(); } Status FlightServerBase::SetShutdownOnSignals(const std::vector sigs) { impl_->signals_ = sigs; @@ -994,7 +221,7 @@ Status FlightServerBase::SetShutdownOnSignals(const std::vector sigs) { } Status FlightServerBase::Serve() { - if (!impl_->server_) { + if (!impl_->transport_) { return Status::UnknownError("Server did not start properly"); } impl_->got_signal_ = 0; @@ -1012,7 +239,7 @@ Status FlightServerBase::Serve() { impl_->old_signal_handlers_.push_back(std::move(old_handler)); } - impl_->server_->Wait(); + RETURN_NOT_OK(impl_->transport_->Wait()); impl_->running_instance_ = nullptr; // Restore signal handlers @@ -1026,21 +253,19 @@ Status FlightServerBase::Serve() { int FlightServerBase::GotSignal() const { return impl_->got_signal_; } Status FlightServerBase::Shutdown(const std::chrono::system_clock::time_point* deadline) { - auto server = impl_->server_.get(); + auto server = impl_->transport_.get(); if (!server) 
{ return Status::Invalid("Shutdown() on uninitialized FlightServerBase"); } impl_->running_instance_ = nullptr; - if (deadline == nullptr) { - impl_->server_->Shutdown(); - } else { - impl_->server_->Shutdown(*deadline); + if (deadline) { + return impl_->transport_->Shutdown(*deadline); } - return Status::OK(); + return impl_->transport_->Shutdown(); } Status FlightServerBase::Wait() { - impl_->server_->Wait(); + RETURN_NOT_OK(impl_->transport_->Wait()); impl_->running_instance_ = nullptr; return Status::OK(); } @@ -1138,7 +363,7 @@ class RecordBatchStream::RecordBatchStreamImpl { RETURN_NOT_OK(reader_->ReadNext(¤t_batch_)); - // TODO(wesm): Delta dictionaries + // TODO(ARROW-10787): Delta dictionaries if (!current_batch_) { // Signal that iteration is over payload->ipc_message.metadata = nullptr; @@ -1167,6 +392,8 @@ class RecordBatchStream::RecordBatchStreamImpl { int dictionary_index_ = 0; }; +FlightMetadataWriter::~FlightMetadataWriter() = default; + FlightDataStream::~FlightDataStream() {} RecordBatchStream::RecordBatchStream(const std::shared_ptr& reader, diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 2f76735ada8..df17f2cc197 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -28,6 +28,7 @@ #include #include "arrow/flight/server_auth.h" +#include "arrow/flight/type_fwd.h" #include "arrow/flight/types.h" // IWYU pragma: keep #include "arrow/flight/visibility.h" // IWYU pragma: keep #include "arrow/ipc/dictionary.h" @@ -41,9 +42,6 @@ class Status; namespace flight { -class ServerMiddleware; -class ServerMiddlewareFactory; - /// \brief Interface that produces a sequence of IPC payloads to be sent in /// FlightData protobuf messages class ARROW_FLIGHT_EXPORT FlightDataStream { @@ -61,7 +59,7 @@ class ARROW_FLIGHT_EXPORT FlightDataStream { }; /// \brief A basic implementation of FlightDataStream that will provide -/// a sequence of FlightData messages to be written to a gRPC stream +/// a sequence of 
FlightData messages to be written to a stream class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream { public: /// \param[in] reader produces a sequence of record batches @@ -184,6 +182,10 @@ class ARROW_FLIGHT_EXPORT FlightServerBase { /// domain socket). int port() const; + /// \brief Get the address that the Flight server is listening on. + /// This method must only be called after Init(). + Location location() const; + /// \brief Set the server to stop when receiving any of the given signal /// numbers. /// This method must be called before Serve(). diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index 8c84ce8da0f..3c28e68597a 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -58,12 +58,11 @@ add_arrow_lib(arrow_flight_sql arrow_flight_static) if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") - set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS - arrow_flight_sql_static arrow_flight_testing_static - ${ARROW_FLIGHT_STATIC_LINK_LIBS} ${ARROW_TEST_LINK_LIBS}) + set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_static + ${ARROW_FLIGHT_TEST_LINK_LIBS}) else() - set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_shared arrow_flight_testing_shared - ${ARROW_TEST_LINK_LIBS}) + set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_shared + ${ARROW_FLIGHT_TEST_LINK_LIBS}) endif() # Build test server for unit tests diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index e72935c144b..199eebb66b3 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -92,7 +92,7 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o std::unique_ptr writer; std::unique_ptr reader; - ARROW_RETURN_NOT_OK(DoPut(options, descriptor, NULLPTR, &writer, &reader)); + ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader)); std::shared_ptr metadata; diff --git 
a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index ad40fb6a6cd..099769ded9d 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -21,9 +21,17 @@ #include "arrow/array/array_base.h" #include "arrow/array/array_dict.h" +#include "arrow/array/util.h" #include "arrow/flight/api.h" #include "arrow/flight/test_util.h" +#include "arrow/table.h" #include "arrow/testing/generator.h" +#include "arrow/util/config.h" +#include "arrow/util/logging.h" + +#if defined(ARROW_CUDA) +#include "arrow/gpu/cuda_api.h" +#endif namespace arrow { namespace flight { @@ -237,6 +245,7 @@ void DataTest::TestOverflowServerBatch() { EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"), reader->ReadAll(&batches)); + ARROW_UNUSED(writer->Close()); } } void DataTest::TestOverflowClientBatch() { @@ -477,17 +486,19 @@ void DataTest::TestDoExchangeError() { FlightStreamChunk chunk; EXPECT_RAISES_WITH_MESSAGE_THAT( NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk)); + ARROW_UNUSED(writer->Close()); } { ASSERT_OK(client_->DoExchange(descr, &writer, &reader)); EXPECT_RAISES_WITH_MESSAGE_THAT( NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema()); + ARROW_UNUSED(writer->Close()); } // writer->Begin isn't tested here because, as noted in client.cc, // OpenRecordBatchWriter lazily writes the initial message - hence - // Begin() won't fail. Additionally, it appears gRPC may buffer - // writes - a write won't immediately fail even when the server - // immediately fails. + // Begin() won't fail. Additionally, transports are allowed to + // buffer writes - a write won't immediately fail even if the server + // would immediately return an error. 
} void DataTest::TestIssue5095() { // Make sure the server-side error message is reflected to the @@ -513,7 +524,17 @@ class DoPutTestServer : public FlightServerBase { std::unique_ptr reader, std::unique_ptr writer) override { descriptor_ = reader->descriptor(); - return reader->ReadAll(&batches_); + int counter = 0; + while (true) { + FlightStreamChunk chunk; + RETURN_NOT_OK(reader->Next(&chunk)); + if (!chunk.data) break; + batches_.push_back(std::move(chunk.data)); + auto buffer = Buffer::FromString(std::to_string(counter)); + RETURN_NOT_OK(writer->WriteMetadata(*buffer)); + counter++; + } + return Status::OK(); } protected: @@ -524,13 +545,16 @@ class DoPutTestServer : public FlightServerBase { }; void DoPutTest::SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); ASSERT_OK(MakeServer( - &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, + location, &server_, &client_, + [](FlightServerOptions* options) { return Status::OK(); }, [](FlightClientOptions* options) { return Status::OK(); })); } void DoPutTest::TearDown() { ASSERT_OK(client_->Close()); ASSERT_OK(server_->Shutdown()); + reinterpret_cast(server_.get())->batches_.clear(); } void DoPutTest::CheckBatches(const FlightDescriptor& expected_descriptor, const RecordBatchVector& expected_batches) { @@ -547,10 +571,21 @@ void DoPutTest::CheckDoPut(const FlightDescriptor& descr, std::unique_ptr stream; std::unique_ptr reader; ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader)); + + // Ensure that the reader can be used independently of the writer + auto* reader_ref = reader.get(); + std::thread reader_thread([reader_ref, &batches]() { + for (size_t i = 0; i < batches.size(); i++) { + std::shared_ptr out; + ASSERT_OK(reader_ref->ReadMetadata(&out)); + } + }); + for (const auto& batch : batches) { ASSERT_OK(stream->WriteRecordBatch(*batch)); } ASSERT_OK(stream->DoneWriting()); + reader_thread.join(); ASSERT_OK(stream->Close()); 
CheckBatches(descr, batches); @@ -579,7 +614,7 @@ void DoPutTest::TestInts() { CheckDoPut(descr, schema, batches); } -void DoPutTest::TestDoPutFloats() { +void DoPutTest::TestFloats() { auto descr = FlightDescriptor::Path({"floats"}); RecordBatchVector batches; auto a0 = ArrayFromJSON(float32(), "[0, 1.2, -3.4, 5.6, null]"); @@ -590,7 +625,7 @@ void DoPutTest::TestDoPutFloats() { CheckDoPut(descr, schema, batches); } -void DoPutTest::TestDoPutEmptyBatch() { +void DoPutTest::TestEmptyBatch() { // Sending and receiving a 0-sized batch shouldn't fail auto descr = FlightDescriptor::Path({"ints"}); RecordBatchVector batches; @@ -601,7 +636,7 @@ void DoPutTest::TestDoPutEmptyBatch() { CheckDoPut(descr, schema, batches); } -void DoPutTest::TestDoPutDicts() { +void DoPutTest::TestDicts() { auto descr = FlightDescriptor::Path({"dicts"}); RecordBatchVector batches; auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]"); @@ -619,7 +654,7 @@ void DoPutTest::TestDoPutDicts() { // Ensure the server is configured to allow large messages by default // Tests a 32 MiB batch -void DoPutTest::TestDoPutLargeBatch() { +void DoPutTest::TestLargeBatch() { auto descr = FlightDescriptor::Path({"large-batches"}); auto schema = ExampleLargeSchema(); RecordBatchVector batches; @@ -627,10 +662,10 @@ void DoPutTest::TestDoPutLargeBatch() { CheckDoPut(descr, schema, batches); } -void DoPutTest::TestDoPutSizeLimit() { +void DoPutTest::TestSizeLimit() { const int64_t size_limit = 4096; - Location location; - ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location)); + ASSERT_OK_AND_ASSIGN(auto location, + Location::ForScheme(transport(), "localhost", server_->port())); auto client_options = FlightClientOptions::Defaults(); client_options.write_size_limit_bytes = size_limit; std::unique_ptr client; @@ -709,8 +744,10 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context, } void AppMetadataTest::SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, 
Location::ForScheme(transport(), "localhost", 0)); ASSERT_OK(MakeServer( - &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, + location, &server_, &client_, + [](FlightServerOptions* options) { return Status::OK(); }, [](FlightClientOptions* options) { return Status::OK(); })); } void AppMetadataTest::TearDown() { @@ -777,9 +814,10 @@ void AppMetadataTest::TestDoPut() { ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], Buffer::FromString(std::to_string(i)))); } - // This eventually calls grpc::ClientReaderWriter::Finish which can - // hang if there are unread messages. So make sure our wrapper - // around this doesn't hang (because it drains any unread messages) + // Transports may behave unpredictably if streams are not + // drained. So explicitly close to see if the transport misbehaves + // (e.g. gRPC will hang if the Flight transport layer doesn't drain + // messages) ASSERT_OK(writer->Close()); } // Test DoPut() with dictionaries. This tests a corner case in the @@ -890,8 +928,10 @@ class IpcOptionsTestServer : public FlightServerBase { }; void IpcOptionsTest::SetUp() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); ASSERT_OK(MakeServer( - &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); }, + location, &server_, &client_, + [](FlightServerOptions* options) { return Status::OK(); }, [](FlightClientOptions* options) { return Status::OK(); })); } void IpcOptionsTest::TearDown() { @@ -976,5 +1016,217 @@ void IpcOptionsTest::TestDoExchangeServerWriteOptions() { ASSERT_RAISES(Invalid, writer->Close()); } +//------------------------------------------------------------ +// Test CUDA memory in data plane methods + +#if defined(ARROW_CUDA) + +Status CheckBuffersOnDevice(const Array& array, const Device& device) { + if (array.num_fields() != 0) { + return Status::NotImplemented("Nested arrays"); + } + for (const auto& buffer : array.data()->buffers) { + if (!buffer) 
continue; + if (!buffer->device()->Equals(device)) { + return Status::Invalid("Expected buffer on device: ", device.ToString(), + ". Was allocated on device: ", buffer->device()->ToString()); + } + } + return Status::OK(); +} + +// Copy a record batch to host memory. +arrow::Result> CopyBatchToHost(const RecordBatch& batch) { + auto mm = CPUDevice::Instance()->default_memory_manager(); + ArrayVector arrays; + for (const auto& column : batch.columns()) { + std::shared_ptr data = column->data()->Copy(); + if (data->child_data.size() != 0) { + return Status::NotImplemented("Nested arrays"); + } + + for (size_t i = 0; i < data->buffers.size(); i++) { + const auto& buffer = data->buffers[i]; + if (!buffer || buffer->is_cpu()) continue; + ARROW_ASSIGN_OR_RAISE(data->buffers[i], Buffer::Copy(buffer, mm)); + } + arrays.push_back(MakeArray(data)); + } + return RecordBatch::Make(batch.schema(), batch.num_rows(), std::move(arrays)); +} + +class CudaTestServer : public FlightServerBase { + public: + explicit CudaTestServer(std::shared_ptr device) : device_(std::move(device)) {} + + Status DoGet(const ServerCallContext&, const Ticket&, + std::unique_ptr* data_stream) override { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(auto batch_reader, RecordBatchReader::Make(batches)); + *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); + return Status::OK(); + } + + Status DoPut(const ServerCallContext&, std::unique_ptr reader, + std::unique_ptr writer) override { + RecordBatchVector batches; + RETURN_NOT_OK(reader->ReadAll(&batches)); + for (const auto& batch : batches) { + for (const auto& column : batch->columns()) { + RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); + } + } + return Status::OK(); + } + + Status DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override { + FlightStreamChunk chunk; + bool begun = false; + while (true) { + 
RETURN_NOT_OK(reader->Next(&chunk)); + if (!chunk.data) break; + if (!begun) { + begun = true; + RETURN_NOT_OK(writer->Begin(chunk.data->schema())); + } + for (const auto& column : chunk.data->columns()) { + RETURN_NOT_OK(CheckBuffersOnDevice(*column, *device_)); + } + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); + } + return Status::OK(); + } + + private: + std::shared_ptr device_; +}; + +// Store CUDA objects without exposing them in the public header +class CudaDataTest::Impl { + public: + cuda::CudaDeviceManager* manager; + std::shared_ptr device; + std::shared_ptr context; +}; + +void CudaDataTest::SetUp() { + ASSERT_OK_AND_ASSIGN(auto manager, cuda::CudaDeviceManager::Instance()); + ASSERT_OK_AND_ASSIGN(auto device, manager->GetDevice(0)); + ASSERT_OK_AND_ASSIGN(auto context, device->GetContext()); + impl_.reset(new Impl()); + impl_->manager = manager; + impl_->device = std::move(device); + impl_->context = std::move(context); + + ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0)); + ASSERT_OK(MakeServer( + location, &server_, &client_, + [this](FlightServerOptions* options) { + options->memory_manager = impl_->device->default_memory_manager(); + return Status::OK(); + }, + [](FlightClientOptions* options) { return Status::OK(); }, impl_->device)); +} +void CudaDataTest::TearDown() { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); +} +void CudaDataTest::TestDoGet() { + // Check that we can allocate the results of DoGet with a custom + // memory manager. + FlightCallOptions options; + options.memory_manager = impl_->device->default_memory_manager(); + + Ticket ticket{""}; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(options, ticket, &stream)); + std::shared_ptr
table; + ASSERT_OK(stream->ReadAll(&table)); + + for (const auto& column : table->columns()) { + for (const auto& chunk : column->chunks()) { + ASSERT_OK(CheckBuffersOnDevice(*chunk, *impl_->device)); + } + } +} +void CudaDataTest::TestDoPut() { + RecordBatchVector batches; + ASSERT_OK(ExampleIntBatches(&batches)); + + std::unique_ptr writer; + std::unique_ptr reader; + auto descriptor = FlightDescriptor::Path({""}); + ASSERT_OK(client_->DoPut(descriptor, batches[0]->schema(), &writer, &reader)); + + ipc::DictionaryMemo memo; + for (const auto& batch : batches) { + ASSERT_OK_AND_ASSIGN(auto buffer, + cuda::SerializeRecordBatch(*batch, impl_->context.get())); + ASSERT_OK_AND_ASSIGN(auto cuda_batch, + cuda::ReadRecordBatch(batch->schema(), &memo, buffer)); + + for (const auto& column : cuda_batch->columns()) { + ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); + } + + ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); + } + ASSERT_OK(writer->Close()); +} +void CudaDataTest::TestDoExchange() { + FlightCallOptions options; + options.memory_manager = impl_->device->default_memory_manager(); + + RecordBatchVector batches; + ASSERT_OK(ExampleIntBatches(&batches)); + + std::unique_ptr writer; + std::unique_ptr reader; + auto descriptor = FlightDescriptor::Path({""}); + ASSERT_OK(client_->DoExchange(options, descriptor, &writer, &reader)); + ASSERT_OK(writer->Begin(batches[0]->schema())); + + ipc::DictionaryMemo write_memo; + ipc::DictionaryMemo read_memo; + for (const auto& batch : batches) { + ASSERT_OK_AND_ASSIGN(auto buffer, + cuda::SerializeRecordBatch(*batch, impl_->context.get())); + ASSERT_OK_AND_ASSIGN(auto cuda_batch, + cuda::ReadRecordBatch(batch->schema(), &write_memo, buffer)); + + for (const auto& column : cuda_batch->columns()) { + ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); + } + + ASSERT_OK(writer->WriteRecordBatch(*cuda_batch)); + + FlightStreamChunk chunk; + ASSERT_OK(reader->Next(&chunk)); + for (const auto& column : 
chunk.data->columns()) { + ASSERT_OK(CheckBuffersOnDevice(*column, *impl_->device)); + } + + // Bounce record batch back to host memory + ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data)); + AssertBatchesEqual(*batch, *host_batch); + } + ASSERT_OK(writer->Close()); +} + +#else + +void CudaDataTest::SetUp() {} +void CudaDataTest::TearDown() {} +void CudaDataTest::TestDoGet() { GTEST_SKIP() << "Arrow was built without ARROW_CUDA"; } +void CudaDataTest::TestDoPut() { GTEST_SKIP() << "Arrow was built without ARROW_CUDA"; } +void CudaDataTest::TestDoExchange() { + GTEST_SKIP() << "Arrow was built without ARROW_CUDA"; +} + +#endif + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h index ff107f802e3..1e256456557 100644 --- a/cpp/src/arrow/flight/test_definitions.h +++ b/cpp/src/arrow/flight/test_definitions.h @@ -86,6 +86,7 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest { std::unique_ptr server_; }; +/// \brief Specific tests of DoPut. class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest { public: void SetUp(); @@ -97,11 +98,11 @@ class ARROW_FLIGHT_EXPORT DoPutTest : public FlightTest { // Test methods void TestInts(); - void TestDoPutFloats(); - void TestDoPutEmptyBatch(); - void TestDoPutDicts(); - void TestDoPutLargeBatch(); - void TestDoPutSizeLimit(); + void TestFloats(); + void TestEmptyBatch(); + void TestDicts(); + void TestLargeBatch(); + void TestSizeLimit(); private: std::unique_ptr client_; @@ -120,6 +121,7 @@ class ARROW_FLIGHT_EXPORT AppMetadataTestServer : public FlightServerBase { std::unique_ptr writer) override; }; +/// \brief Tests of app_metadata in data plane methods. class ARROW_FLIGHT_EXPORT AppMetadataTest : public FlightTest { public: void SetUp(); @@ -137,6 +139,7 @@ class ARROW_FLIGHT_EXPORT AppMetadataTest : public FlightTest { std::unique_ptr server_; }; +/// \brief Tests of IPC options in data plane methods. 
class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest { public: void SetUp(); @@ -154,5 +157,25 @@ class ARROW_FLIGHT_EXPORT IpcOptionsTest : public FlightTest { std::unique_ptr server_; }; +/// \brief Tests of data plane methods with CUDA memory. +/// +/// If not built with ARROW_CUDA, tests are no-ops. +class ARROW_FLIGHT_EXPORT CudaDataTest : public FlightTest { + public: + void SetUp() override; + void TearDown() override; + + // Test methods + void TestDoGet(); + void TestDoPut(); + void TestDoExchange(); + + private: + class Impl; + std::unique_ptr client_; + std::unique_ptr server_; + std::shared_ptr impl_; +}; + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 4405c24eea5..490f53bfb2f 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/flight/platform.h" +#include "arrow/flight/test_util.h" #ifdef __APPLE__ #include @@ -26,13 +26,15 @@ #include #include +// We need Windows fixes before including Boost +#include "arrow/util/windows_compatibility.h" + #include // We need BOOST_USE_WINDOWS_H definition with MinGW when we use // boost/process.hpp. See ARROW_BOOST_PROCESS_COMPILE_DEFINITIONS in // cpp/cmake_modules/BuildUtils.cmake for details. 
-#include - #include +#include #include "arrow/array.h" #include "arrow/array/builder_primitive.h" @@ -43,8 +45,7 @@ #include "arrow/util/logging.h" #include "arrow/flight/api.h" -#include "arrow/flight/internal.h" -#include "arrow/flight/test_util.h" +#include "arrow/flight/serialization_internal.h" namespace arrow { namespace flight { @@ -189,7 +190,7 @@ class FlightTestServer : public FlightServerBase { Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, std::unique_ptr* out) override { - // Test that Arrow-C++ status codes can make it through gRPC + // Test that Arrow-C++ status codes make it through the transport if (request.type == FlightDescriptor::DescriptorType::CMD && request.cmd == "status-outofmemory") { return Status::OutOfMemory("Sentinel"); diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index 75a534d56b3..385eb58fa16 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -30,6 +30,7 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/make_unique.h" #include "arrow/flight/client.h" #include "arrow/flight/client_auth.h" @@ -37,7 +38,6 @@ #include "arrow/flight/server_auth.h" #include "arrow/flight/types.h" #include "arrow/flight/visibility.h" -#include "arrow/util/make_unique.h" namespace boost { namespace process { @@ -104,28 +104,43 @@ std::unique_ptr ExampleTestServer(); // Helper to initialize a server and matching client with callbacks to // populate options. template -Status MakeServer(std::unique_ptr* server, +Status MakeServer(const Location& location, std::unique_ptr* server, std::unique_ptr* client, std::function make_server_options, std::function make_client_options, Args&&... 
server_args) { - Location location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 0, &location)); *server = arrow::internal::make_unique(std::forward(server_args)...); FlightServerOptions server_options(location); RETURN_NOT_OK(make_server_options(&server_options)); RETURN_NOT_OK((*server)->Init(server_options)); Location real_location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", (*server)->port(), &real_location)); + std::string uri = + location.scheme() + "://localhost:" + std::to_string((*server)->port()); + RETURN_NOT_OK(Location::Parse(uri, &real_location)); FlightClientOptions client_options = FlightClientOptions::Defaults(); RETURN_NOT_OK(make_client_options(&client_options)); return FlightClient::Connect(real_location, client_options, client); } +// Helper to initialize a server and matching client with callbacks to +// populate options. +template +Status MakeServer(std::unique_ptr* server, + std::unique_ptr* client, + std::function make_server_options, + std::function make_client_options, + Args&&... 
server_args) { + Location location; + RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 0, &location)); + return MakeServer(location, server, client, std::move(make_server_options), + std::move(make_client_options), + std::forward(server_args)...); +} + // ---------------------------------------------------------------------- // A FlightDataStream that numbers the record batches /// \brief A basic implementation of FlightDataStream that will provide -/// a sequence of FlightData messages to be written to a gRPC stream +/// a sequence of FlightData messages to be written to a stream class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream { public: explicit NumberingStream(std::unique_ptr stream); diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc new file mode 100644 index 00000000000..7a2429d0e39 --- /dev/null +++ b/cpp/src/arrow/flight/transport.cc @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/transport.h" + +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/transport_server.h" +#include "arrow/ipc/message.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/make_unique.h" + +namespace arrow { +namespace flight { +namespace internal { + +::arrow::Result> FlightData::OpenMessage() { + return ipc::Message::Open(metadata, body); +} + +bool TransportDataStream::ReadData(internal::FlightData*) { return false; } +arrow::Result TransportDataStream::WriteData(const FlightPayload&) { + return Status::NotImplemented("Writing data for this stream"); +} +Status TransportDataStream::WritesDone() { return Status::OK(); } +bool ClientDataStream::ReadPutMetadata(std::shared_ptr*) { return false; } +Status ClientDataStream::Finish(Status st) { + auto server_status = DoFinish(); + if (server_status.ok()) return st; + + return Status::FromDetailAndArgs(server_status.code(), server_status.detail(), + server_status.message(), + ". 
Client context: ", st.ToString()); +} + +Status ClientTransport::Authenticate(const FlightCallOptions& options, + std::unique_ptr auth_handler) { + return Status::NotImplemented("Authenticate for this transport"); +} +arrow::Result> +ClientTransport::AuthenticateBasicToken(const FlightCallOptions& options, + const std::string& username, + const std::string& password) { + return Status::NotImplemented("AuthenticateBasicToken for this transport"); +} +Status ClientTransport::DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) { + return Status::NotImplemented("DoAction for this transport"); +} +Status ClientTransport::ListActions(const FlightCallOptions& options, + std::vector* actions) { + return Status::NotImplemented("ListActions for this transport"); +} +Status ClientTransport::GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* info) { + return Status::NotImplemented("GetFlightInfo for this transport"); +} +Status ClientTransport::GetSchema(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* schema_result) { + return Status::NotImplemented("GetSchema for this transport"); +} +Status ClientTransport::ListFlights(const FlightCallOptions& options, + const Criteria& criteria, + std::unique_ptr* listing) { + return Status::NotImplemented("ListFlights for this transport"); +} +Status ClientTransport::DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* stream) { + return Status::NotImplemented("DoGet for this transport"); +} +Status ClientTransport::DoPut(const FlightCallOptions& options, + std::unique_ptr* stream) { + return Status::NotImplemented("DoPut for this transport"); +} +Status ClientTransport::DoExchange(const FlightCallOptions& options, + std::unique_ptr* stream) { + return Status::NotImplemented("DoExchange for this transport"); +} + +class TransportRegistry::Impl final { + public: + 
arrow::Result> MakeClient( + const std::string& scheme) const { + auto it = client_factories_.find(scheme); + if (it == client_factories_.end()) { + return Status::KeyError("No client transport implementation for ", scheme); + } + return it->second(); + } + arrow::Result> MakeServer( + const std::string& scheme, FlightServerBase* base, + std::shared_ptr memory_manager) const { + auto it = server_factories_.find(scheme); + if (it == server_factories_.end()) { + return Status::KeyError("No server transport implementation for ", scheme); + } + return it->second(base, std::move(memory_manager)); + } + Status RegisterClient(const std::string& scheme, ClientFactory factory) { + auto it = client_factories_.insert({scheme, std::move(factory)}); + if (!it.second) { + return Status::Invalid("Client transport already registered for ", scheme); + } + return Status::OK(); + } + Status RegisterServer(const std::string& scheme, ServerFactory factory) { + auto it = server_factories_.insert({scheme, std::move(factory)}); + if (!it.second) { + return Status::Invalid("Server transport already registered for ", scheme); + } + return Status::OK(); + } + + private: + std::unordered_map client_factories_; + std::unordered_map server_factories_; +}; + +TransportRegistry::TransportRegistry() { impl_ = arrow::internal::make_unique(); } +TransportRegistry::~TransportRegistry() = default; +arrow::Result> TransportRegistry::MakeClient( + const std::string& scheme) const { + return impl_->MakeClient(scheme); +} +arrow::Result> TransportRegistry::MakeServer( + const std::string& scheme, FlightServerBase* base, + std::shared_ptr memory_manager) const { + return impl_->MakeServer(scheme, base, std::move(memory_manager)); +} +Status TransportRegistry::RegisterClient(const std::string& scheme, + ClientFactory factory) { + return impl_->RegisterClient(scheme, std::move(factory)); +} +Status TransportRegistry::RegisterServer(const std::string& scheme, + ServerFactory factory) { + return 
impl_->RegisterServer(scheme, std::move(factory)); +} + +TransportRegistry* GetDefaultTransportRegistry() { + static TransportRegistry kRegistry; + return &kRegistry; +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h new file mode 100644 index 00000000000..085cfa99473 --- /dev/null +++ b/cpp/src/arrow/flight/transport.h @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// \file +/// Internal (but not private) interface for implementing +/// alternate network transports in Flight. +/// +/// \warning EXPERIMENTAL. Subject to change. +/// +/// To implement a transport, implement ServerTransport and +/// ClientTransport, and register the desired URI schemes with +/// TransportRegistry. Flight takes care of most of the per-RPC +/// details; transports only handle connections and providing a I/O +/// stream implementation (TransportDataStream). +/// +/// On the server side: +/// +/// 1. Applications subclass FlightServerBase and override RPC handlers. +/// 2. FlightServerBase::Init will look up and create a ServerTransport +/// based on the scheme of the Location given to it. +/// 3. 
The ServerTransport will start the actual server. (For instance, +/// for gRPC, it creates a gRPC server and registers a gRPC service.) +/// That server will handle connections. +/// 4. The transport should forward incoming calls to the server to the RPC +/// handlers defined on ServerTransport, which implements the actual +/// RPC handler using the interfaces here. Any I/O the RPC handler needs +/// to do is managed by transport-specific implementations of +/// TransportDataStream. +/// 5. ServerTransport calls FlightServerBase for the actual application +/// logic. +/// +/// On the client side: +/// +/// 1. Applications create a FlightClient with a Location. +/// 2. FlightClient will look up and create a ClientTransport based on +/// the scheme of the Location given to it. +/// 3. When calling a method on FlightClient, FlightClient will delegate to +/// the ClientTransport. There is some indirection, e.g. for DoGet, +/// FlightClient only requests that the ClientTransport start the +/// call and provide it with an I/O stream. The "Flight implementation" +/// itself still lives in FlightClient. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/flight/type_fwd.h" +#include "arrow/flight/visibility.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace ipc { +class Message; +} +namespace flight { +namespace internal { + +/// Internal, not user-visible type used for memory-efficient reads +struct FlightData { + /// Used only for puts, may be null + std::unique_ptr descriptor; + + /// Non-length-prefixed Message header as described in format/Message.fbs + std::shared_ptr metadata; + + /// Application-defined metadata + std::shared_ptr app_metadata; + + /// Message body + std::shared_ptr body; + + /// Open IPC message from the metadata and body + ::arrow::Result> OpenMessage(); +}; + +/// \brief A transport-specific interface for reading/writing Arrow data. 
+/// +/// New transports will implement this to read/write IPC payloads to +/// the underlying stream. +class ARROW_FLIGHT_EXPORT TransportDataStream { + public: + virtual ~TransportDataStream() = default; + /// \brief Attempt to read the next FlightData message. + /// + /// \return success true if data was populated, false if there was + /// an error. For clients, the error can be retrieved from + /// Finish(Status). + virtual bool ReadData(FlightData* data); + /// \brief Attempt to write a FlightPayload. + /// + /// \param[in] payload The data to write. + /// \return true if the message was accepted by the transport, false + /// if not (e.g. due to client/server disconnect), Status if there + /// was an error (e.g. with the payload itself). + virtual arrow::Result WriteData(const FlightPayload& payload); + /// \brief Indicate that there are no more writes on this stream. + /// + /// This is only a hint for the underlying transport and may not + /// actually do anything. + virtual Status WritesDone(); +}; + +/// \brief A transport-specific interface for reading/writing Arrow +/// data for a client. +class ARROW_FLIGHT_EXPORT ClientDataStream : public TransportDataStream { + public: + /// \brief Attempt to read a non-data message. + /// + /// Only implemented for DoPut; mutually exclusive with + /// ReadData(FlightData*). + virtual bool ReadPutMetadata(std::shared_ptr* out); + /// \brief Attempt to cancel the call. + /// + /// This is only a hint and may not take effect immediately. The + /// client should still finish the call with Finish(Status) as usual. + virtual void TryCancel() {} + /// \brief Finish the call, reporting the server-sent status and/or + /// any client-side errors as appropriate. + /// + /// Implies WritesDone() and DoFinish(). + /// + /// \param[in] st A client-side status to combine with the + /// server-side error. 
That is, if an error occurs on the + /// client-side, call Finish(Status) to finish the server-side + /// call, get the server-side status, and merge the statuses + /// together so context is not lost. + Status Finish(Status st); + + protected: + /// \brief End the call, returning the final server status. + /// + /// For implementors: should imply WritesDone() (even if it does not + /// directly call it). + /// + /// Implies WritesDone(). + virtual Status DoFinish() = 0; +}; + +/// An implementation of a Flight client for a particular transport. +/// +/// Transports should override the methods they are capable of +/// supporting. The default method implementations return an error. +class ARROW_FLIGHT_EXPORT ClientTransport { + public: + virtual ~ClientTransport() = default; + + /// Initialize the client. + virtual Status Init(const FlightClientOptions& options, const Location& location, + const arrow::internal::Uri& uri) = 0; + /// Close the client. Once this returns, the client is no longer usable. 
+ virtual Status Close() = 0; + + virtual Status Authenticate(const FlightCallOptions& options, + std::unique_ptr auth_handler); + virtual arrow::Result> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password); + virtual Status DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr* results); + virtual Status ListActions(const FlightCallOptions& options, + std::vector* actions); + virtual Status GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* info); + virtual Status GetSchema(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* schema_result); + virtual Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, + std::unique_ptr* listing); + virtual Status DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* stream); + virtual Status DoPut(const FlightCallOptions& options, + std::unique_ptr* stream); + virtual Status DoExchange(const FlightCallOptions& options, + std::unique_ptr* stream); +}; + +/// A registry of transport implementations. +class ARROW_FLIGHT_EXPORT TransportRegistry { + public: + using ClientFactory = std::function>()>; + using ServerFactory = std::function>( + FlightServerBase*, std::shared_ptr memory_manager)>; + + TransportRegistry(); + ~TransportRegistry(); + + arrow::Result> MakeClient( + const std::string& scheme) const; + arrow::Result> MakeServer( + const std::string& scheme, FlightServerBase* base, + std::shared_ptr memory_manager) const; + + Status RegisterClient(const std::string& scheme, ClientFactory factory); + Status RegisterServer(const std::string& scheme, ServerFactory factory); + + private: + class Impl; + std::unique_ptr impl_; +}; + +/// \brief Get the registry of transport implementations. 
+ARROW_FLIGHT_EXPORT +TransportRegistry* GetDefaultTransportRegistry(); + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/customize_protobuf.h b/cpp/src/arrow/flight/transport/grpc/customize_grpc.h similarity index 80% rename from cpp/src/arrow/flight/customize_protobuf.h rename to cpp/src/arrow/flight/transport/grpc/customize_grpc.h index 1508af254dd..1085a946966 100644 --- a/cpp/src/arrow/flight/customize_protobuf.h +++ b/cpp/src/arrow/flight/transport/grpc/customize_grpc.h @@ -21,6 +21,7 @@ #include #include "arrow/flight/platform.h" +#include "arrow/flight/type_fwd.h" #include "arrow/util/config.h" // Silence protobuf warnings @@ -55,29 +56,26 @@ class ByteBuffer; namespace arrow { namespace flight { -struct FlightPayload; +namespace protocol { -namespace internal { +class FlightData; -struct FlightData; +} // namespace protocol -// Those two functions are defined in serialization-internal.cc +namespace transport { +namespace grpc { +// Those two functions are defined in serialization_internal.cc // Write FlightData to a grpc::ByteBuffer without extra copying -grpc::Status FlightDataSerialize(const FlightPayload& msg, grpc::ByteBuffer* out, - bool* own_buffer); +::grpc::Status FlightDataSerialize(const arrow::flight::FlightPayload& msg, + ::grpc::ByteBuffer* out, bool* own_buffer); // Read internal::FlightData from grpc::ByteBuffer containing FlightData // protobuf without copying -grpc::Status FlightDataDeserialize(grpc::ByteBuffer* buffer, FlightData* out); - -} // namespace internal - -namespace protocol { - -class FlightData; - -} // namespace protocol +::grpc::Status FlightDataDeserialize(::grpc::ByteBuffer* buffer, + arrow::flight::internal::FlightData* out); +} // namespace grpc +} // namespace transport } // namespace flight } // namespace arrow @@ -95,12 +93,12 @@ class SerializationTraits { // In the functions below, we cast back the Message argument to its real // type (see ReadPayload() and 
WritePayload() for the initial cast). static Status Serialize(const MessageType& msg, ByteBuffer* bb, bool* own_buffer) { - return arrow::flight::internal::FlightDataSerialize( + return arrow::flight::transport::grpc::FlightDataSerialize( *reinterpret_cast(&msg), bb, own_buffer); } static Status Deserialize(ByteBuffer* buffer, MessageType* msg) { - return arrow::flight::internal::FlightDataDeserialize( + return arrow::flight::transport::grpc::FlightDataDeserialize( buffer, reinterpret_cast(msg)); } }; diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc new file mode 100644 index 00000000000..9af58ba4d28 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc @@ -0,0 +1,920 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/transport/grpc/grpc_client.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/util/config.h" +#ifdef GRPCPP_PP_INCLUDE +#include +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) +#include +#endif +#else +#include +#endif + +#include + +#include "arrow/buffer.h" +#include "arrow/device.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/base64.h" +#include "arrow/util/logging.h" +#include "arrow/util/uri.h" + +#include "arrow/flight/client.h" +#include "arrow/flight/client_auth.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/cookie_internal.h" +#include "arrow/flight/middleware.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/grpc/serialization_internal.h" +#include "arrow/flight/transport/grpc/util_internal.h" +#include "arrow/flight/types.h" + +namespace arrow { + +namespace flight { +namespace transport { +namespace grpc { + +namespace { +namespace pb = arrow::flight::protocol; + +struct ClientRpc { + ::grpc::ClientContext context; + + explicit ClientRpc(const FlightCallOptions& options) { + if (options.timeout.count() >= 0) { + std::chrono::system_clock::time_point deadline = + std::chrono::time_point_cast( + std::chrono::system_clock::now() + options.timeout); + context.set_deadline(deadline); + } + for (auto header : options.headers) { + context.AddMetadata(header.first, header.second); + } + } + + /// \brief Add an auth token via an auth handler + Status SetToken(ClientAuthHandler* auth_handler) { + if (auth_handler) { + std::string token; + RETURN_NOT_OK(auth_handler->GetToken(&token)); + context.AddMetadata(kGrpcAuthHeader, token); + } + return Status::OK(); + } +}; + +class GrpcAddClientHeaders : public AddCallHeaders { + public: + explicit GrpcAddClientHeaders(std::multimap<::grpc::string, ::grpc::string>* metadata) + : metadata_(metadata) {} + 
~GrpcAddClientHeaders() override = default; + + void AddHeader(const std::string& key, const std::string& value) override { + metadata_->insert(std::make_pair(key, value)); + } + + private: + std::multimap<::grpc::string, ::grpc::string>* metadata_; +}; + +class GrpcClientInterceptorAdapter : public ::grpc::experimental::Interceptor { + public: + explicit GrpcClientInterceptorAdapter( + std::vector> middleware) + : middleware_(std::move(middleware)), received_headers_(false) {} + + void Intercept(::grpc::experimental::InterceptorBatchMethods* methods) { + using InterceptionHookPoints = ::grpc::experimental::InterceptionHookPoints; + if (methods->QueryInterceptionHookPoint( + InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + GrpcAddClientHeaders add_headers(methods->GetSendInitialMetadata()); + for (const auto& middleware : middleware_) { + middleware->SendingHeaders(&add_headers); + } + } + + if (methods->QueryInterceptionHookPoint( + InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + if (!methods->GetRecvInitialMetadata()->empty()) { + ReceivedHeaders(*methods->GetRecvInitialMetadata()); + } + } + + if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_STATUS)) { + DCHECK_NE(nullptr, methods->GetRecvStatus()); + DCHECK_NE(nullptr, methods->GetRecvTrailingMetadata()); + ReceivedHeaders(*methods->GetRecvTrailingMetadata()); + const Status status = FromGrpcStatus(*methods->GetRecvStatus()); + for (const auto& middleware : middleware_) { + middleware->CallCompleted(status); + } + } + + methods->Proceed(); + } + + private: + void ReceivedHeaders( + const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) { + if (received_headers_) { + return; + } + received_headers_ = true; + CallHeaders headers; + for (const auto& entry : metadata) { + headers.insert({util::string_view(entry.first.data(), entry.first.length()), + util::string_view(entry.second.data(), entry.second.length())}); + } + for (const auto& middleware : 
middleware_) { + middleware->ReceivedHeaders(headers); + } + } + + std::vector> middleware_; + // When communicating with a gRPC-Java server, the server may not + // send back headers if the call fails right away. Instead, the + // headers will be consolidated into the trailers. We don't want to + // call the client middleware callback twice, so instead track + // whether we saw headers - if not, then we need to check trailers. + bool received_headers_; +}; + +class GrpcClientInterceptorAdapterFactory + : public ::grpc::experimental::ClientInterceptorFactoryInterface { + public: + GrpcClientInterceptorAdapterFactory( + std::vector> middleware) + : middleware_(middleware) {} + + ::grpc::experimental::Interceptor* CreateClientInterceptor( + ::grpc::experimental::ClientRpcInfo* info) override { + std::vector> middleware; + + FlightMethod flight_method = FlightMethod::Invalid; + util::string_view method(info->method()); + if (method.ends_with("/Handshake")) { + flight_method = FlightMethod::Handshake; + } else if (method.ends_with("/ListFlights")) { + flight_method = FlightMethod::ListFlights; + } else if (method.ends_with("/GetFlightInfo")) { + flight_method = FlightMethod::GetFlightInfo; + } else if (method.ends_with("/GetSchema")) { + flight_method = FlightMethod::GetSchema; + } else if (method.ends_with("/DoGet")) { + flight_method = FlightMethod::DoGet; + } else if (method.ends_with("/DoPut")) { + flight_method = FlightMethod::DoPut; + } else if (method.ends_with("/DoExchange")) { + flight_method = FlightMethod::DoExchange; + } else if (method.ends_with("/DoAction")) { + flight_method = FlightMethod::DoAction; + } else if (method.ends_with("/ListActions")) { + flight_method = FlightMethod::ListActions; + } else { + ARROW_LOG(WARNING) << "Unknown Flight method: " << info->method(); + flight_method = FlightMethod::Invalid; + } + + const CallInfo flight_info{flight_method}; + for (auto& factory : middleware_) { + std::unique_ptr instance; + 
factory->StartCall(flight_info, &instance); + if (instance) { + middleware.push_back(std::move(instance)); + } + } + return new GrpcClientInterceptorAdapter(std::move(middleware)); + } + + private: + std::vector> middleware_; +}; + +class GrpcClientAuthSender : public ClientAuthSender { + public: + explicit GrpcClientAuthSender( + std::shared_ptr< + ::grpc::ClientReaderWriter> + stream) + : stream_(stream) {} + + Status Write(const std::string& token) override { + pb::HandshakeRequest response; + response.set_payload(token); + if (stream_->Write(response)) { + return Status::OK(); + } + return FromGrpcStatus(stream_->Finish()); + } + + private: + std::shared_ptr<::grpc::ClientReaderWriter> + stream_; +}; + +class GrpcClientAuthReader : public ClientAuthReader { + public: + explicit GrpcClientAuthReader( + std::shared_ptr< + ::grpc::ClientReaderWriter> + stream) + : stream_(stream) {} + + Status Read(std::string* token) override { + pb::HandshakeResponse request; + if (stream_->Read(&request)) { + *token = std::move(*request.mutable_payload()); + return Status::OK(); + } + return FromGrpcStatus(stream_->Finish()); + } + + private: + std::shared_ptr<::grpc::ClientReaderWriter> + stream_; +}; + +/// \brief The base of the ClientDataStream implementation for gRPC. +template +class FinishableDataStream : public internal::ClientDataStream { + public: + FinishableDataStream(std::shared_ptr rpc, std::shared_ptr stream, + std::shared_ptr memory_manager) + : rpc_(std::move(rpc)), + stream_(std::move(stream)), + memory_manager_(memory_manager ? std::move(memory_manager) + : CPUDevice::Instance()->default_memory_manager()), + finished_(false) {} + + void TryCancel() override { rpc_->context.TryCancel(); } + + protected: + Status DoFinish() override { + if (finished_) { + return server_status_; + } + + // Drain the read side, as otherwise gRPC Finish() will hang. 
We + // only call Finish() when the client closes the writer or the + // reader finishes, so it's OK to assume the client no longer + // wants to read and drain the read side. (If the client wants to + // indicate that it is done writing, but not done reading, it + // should use DoneWriting. + ReadPayloadType message; + while (ReadPayload(stream_.get(), &message)) { + // Drain the read side to avoid gRPC hanging in Finish() + } + + server_status_ = FromGrpcStatus(stream_->Finish(), &rpc_->context); + if (!server_status_.ok()) { + server_status_ = Status::FromDetailAndArgs( + server_status_.code(), server_status_.detail(), server_status_.message(), + ". gRPC client debug context: ", rpc_->context.debug_error_string()); + } + if (!transport_status_.ok()) { + if (server_status_.ok()) { + server_status_ = transport_status_; + } else { + server_status_ = Status::FromDetailAndArgs( + server_status_.code(), server_status_.detail(), server_status_.message(), + ". gRPC client debug context: ", rpc_->context.debug_error_string(), + ". Additional context: ", transport_status_.ToString()); + } + } + finished_ = true; + + return server_status_; + } + + std::shared_ptr rpc_; + std::shared_ptr stream_; + std::shared_ptr memory_manager_; + bool finished_; + Status server_status_; + // A transport-side error that needs to get combined with the server status + Status transport_status_; +}; + +/// \brief A ClientDataStream implementation for gRPC that manages a +/// mutex to protect from concurrent reads/writes, and drains the +/// read side on finish. 
+template +class WritableDataStream : public FinishableDataStream { + public: + using Base = FinishableDataStream; + WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream, + std::shared_ptr memory_manager) + : Base(std::move(rpc), std::move(stream), std::move(memory_manager)), + read_mutex_(), + finish_mutex_(), + done_writing_(false) {} + + Status WritesDone() override { + // This is only used by the writer side of a stream, so it need + // not be protected with a lock. + if (done_writing_) { + return Status::OK(); + } + done_writing_ = true; + if (!stream_->WritesDone()) { + // Error happened, try to close the stream to get more detailed info + return this->Finish(MakeFlightError(FlightStatusCode::Internal, + "Could not flush pending record batches")); + } + return Status::OK(); + } + + protected: + Status DoFinish() override { + // This may be used concurrently by reader/writer side of a + // stream, so it needs to be protected. + std::lock_guard guard(finish_mutex_); + + // Now that we're shared between a reader and writer, we need to + // protect ourselves from being called while there's an + // outstanding read. + std::unique_lock read_guard(read_mutex_, std::try_to_lock); + if (!read_guard.owns_lock()) { + return MakeFlightError(FlightStatusCode::Internal, + "Cannot close stream with pending read operation."); + } + + // Try to flush pending writes. Don't use our WritesDone() to + // avoid recursion. + bool finished_writes = done_writing_ || stream_->WritesDone(); + done_writing_ = true; + + Status st = Base::DoFinish(); + if (!finished_writes) { + return Status::FromDetailAndArgs( + st.code(), st.detail(), st.message(), + ". 
Additionally, could not finish writing record batches before closing"); + } + return st; + } + + using Base::stream_; + std::mutex read_mutex_; + std::mutex finish_mutex_; + bool done_writing_; +}; + +class GrpcClientGetStream + : public FinishableDataStream<::grpc::ClientReader, + internal::FlightData> { + public: + using FinishableDataStream::FinishableDataStream; + + bool ReadData(internal::FlightData* data) override { + bool success = ReadPayload(stream_.get(), data); + if (ARROW_PREDICT_FALSE(!success)) return false; + if (data->body) { + auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body); + if (!status.ok()) { + transport_status_ = std::move(status); + return false; + } + } + return true; + } + Status WritesDone() override { return Status::NotImplemented("NYI"); } +}; + +class GrpcClientPutStream + : public WritableDataStream<::grpc::ClientReaderWriter, + pb::PutResult> { + public: + using Stream = ::grpc::ClientReaderWriter; + GrpcClientPutStream(std::shared_ptr rpc, std::shared_ptr stream, + std::shared_ptr memory_manager) + : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) { + } + + bool ReadPutMetadata(std::shared_ptr* out) override { + std::lock_guard guard(read_mutex_); + pb::PutResult message; + if (stream_->Read(&message)) { + *out = Buffer::FromString(std::move(*message.mutable_app_metadata())); + } else { + // Stream finished + *out = nullptr; + } + return true; + } + arrow::Result WriteData(const FlightPayload& payload) override { + return WritePayload(payload, this->stream_.get()); + } +}; + +class GrpcClientExchangeStream + : public WritableDataStream< + ::grpc::ClientReaderWriter, + internal::FlightData> { + public: + using Stream = ::grpc::ClientReaderWriter; + GrpcClientExchangeStream(std::shared_ptr rpc, std::shared_ptr stream, + std::shared_ptr memory_manager) + : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) { + } + + bool 
ReadData(internal::FlightData* data) override { + std::lock_guard guard(read_mutex_); + bool success = ReadPayload(stream_.get(), data); + if (ARROW_PREDICT_FALSE(!success)) return false; + if (data->body) { + auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body); + if (!status.ok()) { + transport_status_ = std::move(status); + return false; + } + } + return true; + } + arrow::Result WriteData(const FlightPayload& payload) override { + return WritePayload(payload, this->stream_.get()); + } +}; + +static constexpr char kBearerPrefix[] = "Bearer "; +static constexpr char kBasicPrefix[] = "Basic "; + +/// \brief Add base64 encoded credentials to the outbound headers. +/// +/// \param context Context object to add the headers to. +/// \param username Username to format and encode. +/// \param password Password to format and encode. +void AddBasicAuthHeaders(::grpc::ClientContext* context, const std::string& username, + const std::string& password) { + const std::string credentials = username + ":" + password; + context->AddMetadata(internal::kAuthHeader, + kBasicPrefix + arrow::util::base64_encode(credentials)); +} + +/// \brief Get bearer token from inbound headers. +/// +/// \param context Incoming ClientContext that contains headers. +/// \return Arrow result with bearer token (empty if no bearer token found). +arrow::Result> GetBearerTokenHeader( + ::grpc::ClientContext& context) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + // Get the auth token if it exists, this can be in the initial or the trailing metadata. 
+ auto trailing_headers = context.GetServerTrailingMetadata(); + auto initial_headers = context.GetServerInitialMetadata(); + auto bearer_iter = trailing_headers.find(internal::kAuthHeader); + if (bearer_iter == trailing_headers.end()) { + bearer_iter = initial_headers.find(internal::kAuthHeader); + if (bearer_iter == initial_headers.end()) { + return std::make_pair("", ""); + } + } + + // Check if the value of the auth token starts with the bearer prefix and latch it. + std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size()); + if (bearer_val.size() > strlen(kBearerPrefix)) { + if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), + kBearerPrefix, char_compare)) { + return std::make_pair(internal::kAuthHeader, bearer_val); + } + } + + // The server is not required to provide a bearer token. + return std::make_pair("", ""); +} + +// Dummy self-signed certificate to be used because TlsCredentials +// requires root CA certs, even if you are skipping server +// verification. 
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) +constexpr char kDummyRootCert[] = + "-----BEGIN CERTIFICATE-----\n" + "MIICwzCCAaugAwIBAgIJAM12DOkcaqrhMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV\n" + "BAMTCWxvY2FsaG9zdDAeFw0yMDEwMDcwODIyNDFaFw0zMDEwMDUwODIyNDFaMBQx\n" + "EjAQBgNVBAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC\n" + "ggEBALjJ8KPEpF0P4GjMPrJhjIBHUL0AX9E4oWdgJRCSFkPKKEWzQabTQBikMOhI\n" + "W4VvBMaHEBuECE5OEyrzDRiAO354I4F4JbBfxMOY8NIW0uWD6THWm2KkCzoZRIPW\n" + "yZL6dN+mK6cEH+YvbNuy5ZQGNjGG43tyiXdOCAc4AI9POeTtjdMpbbpR2VY4Ad/E\n" + "oTEiS3gNnN7WIAdgMhCJxjzvPwKszV3f7pwuTHzFMsuHLKr6JeaVUYfbi4DxxC8Z\n" + "k6PF6dLlLf3ngTSLBJyaXP1BhKMvz0TaMK3F0y2OGwHM9J8np2zWjTlNVEzffQZx\n" + "SWMOQManlJGs60xYx9KCPJMZZsMCAwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxo\n" + "b3N0MA0GCSqGSIb3DQEBBQUAA4IBAQC0LrmbcNKgO+D50d/wOc+vhi9K04EZh8bg\n" + "WYAK1kLOT4eShbzqWGV/1EggY4muQ6ypSELCLuSsg88kVtFQIeRilA6bHFqQSj6t\n" + "sqgh2cWsMwyllCtmX6Maf3CLb2ZdoJlqUwdiBdrbIbuyeAZj3QweCtLKGSQzGDyI\n" + "KH7G8nC5d0IoRPiCMB6RnMMKsrhviuCdWbAFHop7Ff36JaOJ8iRa2sSf2OXE8j/5\n" + "obCXCUvYHf4Zw27JcM2AnnQI9VJLnYxis83TysC5s2Z7t0OYNS9kFmtXQbUNlmpS\n" + "doQ/Eu47vWX7S0TXeGziGtbAOKxbHE0BGGPDOAB/jGW/JVbeTiXY\n" + "-----END CERTIFICATE-----\n"; +#endif + +class GrpcClientImpl : public internal::ClientTransport { + public: + static arrow::Result> Make() { + return std::unique_ptr(new GrpcClientImpl()); + } + + Status Init(const FlightClientOptions& options, const Location& location, + const arrow::internal::Uri& uri) override { + const std::string& scheme = location.scheme(); + + std::stringstream grpc_uri; + std::shared_ptr<::grpc::ChannelCredentials> creds; + if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { + grpc_uri << arrow::internal::UriEncodeHost(uri.host()) << ':' << uri.port_text(); + + if (scheme == kSchemeGrpcTls) { + if (options.disable_server_verification) { +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + namespace ge = 
::GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS; + +#if defined(GRPC_USE_CERTIFICATE_VERIFIER) + // gRPC >= 1.43 + class NoOpCertificateVerifier : public ge::ExternalCertificateVerifier { + public: + bool Verify(ge::TlsCustomVerificationCheckRequest*, + std::function, + ::grpc::Status* sync_status) override { + *sync_status = ::grpc::Status::OK; + return true; // Check done synchronously + } + void Cancel(ge::TlsCustomVerificationCheckRequest*) override {} + }; + auto cert_verifier = + ge::ExternalCertificateVerifier::Create(); + +#else // defined(GRPC_USE_CERTIFICATE_VERIFIER) + // gRPC < 1.43 + // A callback to supply to TlsCredentialsOptions that accepts any server + // arguments. + struct NoOpTlsAuthorizationCheck + : public ge::TlsServerAuthorizationCheckInterface { + int Schedule(ge::TlsServerAuthorizationCheckArg* arg) override { + arg->set_success(1); + arg->set_status(GRPC_STATUS_OK); + return 0; + } + }; + auto server_authorization_check = std::make_shared(); + noop_auth_check_ = std::make_shared( + server_authorization_check); +#endif // defined(GRPC_USE_CERTIFICATE_VERIFIER) + +#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) + auto certificate_provider = + std::make_shared<::grpc::experimental::StaticDataCertificateProvider>( + kDummyRootCert); +#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) + ::grpc::experimental::TlsChannelCredentialsOptions tls_options( + certificate_provider); +#else // defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) + // While gRPC >= 1.36 does not require a root cert (it has a default) + // in practice the path it hardcodes is broken. See grpc/grpc#21655. 
+ ::grpc::experimental::TlsChannelCredentialsOptions tls_options; + tls_options.set_certificate_provider(certificate_provider); +#endif // defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) + tls_options.watch_root_certs(); + tls_options.set_root_cert_name("dummy"); +#if defined(GRPC_USE_CERTIFICATE_VERIFIER) + tls_options.set_certificate_verifier(std::move(cert_verifier)); + tls_options.set_check_call_host(false); + tls_options.set_verify_server_certs(false); +#else // defined(GRPC_USE_CERTIFICATE_VERIFIER) + tls_options.set_server_verification_option( + grpc_tls_server_verification_option::GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION); + tls_options.set_server_authorization_check_config(noop_auth_check_); +#endif // defined(GRPC_USE_CERTIFICATE_VERIFIER) +#elif defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + // continues defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) + auto materials_config = std::make_shared(); + materials_config->set_pem_root_certs(kDummyRootCert); + ge::TlsCredentialsOptions tls_options( + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, + GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION, materials_config, + std::shared_ptr(), noop_auth_check_); +#endif // defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) + creds = ge::TlsCredentials(tls_options); +#else // defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + return Status::NotImplemented( + "Using encryption with server verification disabled is unsupported. 
" + "Please use a release of Arrow Flight built with gRPC 1.27 or higher."); +#endif // defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + } else { + ::grpc::SslCredentialsOptions ssl_options; + if (!options.tls_root_certs.empty()) { + ssl_options.pem_root_certs = options.tls_root_certs; + } + if (!options.cert_chain.empty()) { + ssl_options.pem_cert_chain = options.cert_chain; + } + if (!options.private_key.empty()) { + ssl_options.pem_private_key = options.private_key; + } + creds = ::grpc::SslCredentials(ssl_options); + } + } else { + creds = ::grpc::InsecureChannelCredentials(); + } + } else if (scheme == kSchemeGrpcUnix) { + grpc_uri << "unix://" << uri.path(); + creds = ::grpc::InsecureChannelCredentials(); + } else { + return Status::NotImplemented("Flight scheme ", scheme, + " is not supported by the gRPC transport"); + } + + ::grpc::ChannelArguments args; + // We can't set the same config value twice, so for values where + // we want to set defaults, keep them in a map and update them; + // then update them all at once + std::unordered_map default_args; + // Try to reconnect quickly at first, in case the server is still starting up + default_args[GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS] = 100; + // Receive messages of any size + default_args[GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH] = -1; + // Setting this arg enables each client to open it's own TCP connection to server, + // not sharing one single connection, which becomes bottleneck under high load. + default_args[GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL] = 1; + + if (options.override_hostname != "") { + args.SetSslTargetNameOverride(options.override_hostname); + } + + // Allow setting generic gRPC options. 
+ for (const auto& arg : options.generic_options) { + if (util::holds_alternative(arg.second)) { + default_args[arg.first] = util::get(arg.second); + } else if (util::holds_alternative(arg.second)) { + args.SetString(arg.first, util::get(arg.second)); + } + // Otherwise unimplemented + } + for (const auto& pair : default_args) { + args.SetInt(pair.first, pair.second); + } + + std::vector> + interceptors; + interceptors.emplace_back( + new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); + + stub_ = pb::FlightService::NewStub( + ::grpc::experimental::CreateCustomChannelWithInterceptors( + grpc_uri.str(), creds, args, std::move(interceptors))); + return Status::OK(); + } + + Status Close() override { + // TODO(ARROW-15473): if we track ongoing RPCs, we can cancel them first + // gRPC does not offer a real Close(). We could reset() the gRPC + // client but that can cause gRPC to hang in shutdown + // (ARROW-15793). + return Status::OK(); + } + + Status Authenticate(const FlightCallOptions& options, + std::unique_ptr auth_handler) override { + auth_handler_ = std::move(auth_handler); + ClientRpc rpc(options); + std::shared_ptr< + ::grpc::ClientReaderWriter> + stream = stub_->Handshake(&rpc.context); + GrpcClientAuthSender outgoing{stream}; + GrpcClientAuthReader incoming{stream}; + RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming)); + // Explicitly close our side of the connection + bool finished_writes = stream->WritesDone(); + RETURN_NOT_OK(FromGrpcStatus(stream->Finish(), &rpc.context)); + if (!finished_writes) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not finish writing before closing"); + } + return Status::OK(); + } + + arrow::Result> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password) override { + // Add basic auth headers to outgoing headers. 
+ ClientRpc rpc(options); + AddBasicAuthHeaders(&rpc.context, username, password); + std::shared_ptr< + ::grpc::ClientReaderWriter> + stream = stub_->Handshake(&rpc.context); + // Explicitly close our side of the connection. + bool finished_writes = stream->WritesDone(); + RETURN_NOT_OK(FromGrpcStatus(stream->Finish(), &rpc.context)); + if (!finished_writes) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not finish writing before closing"); + } + // Grab bearer token from incoming headers. + return GetBearerTokenHeader(rpc.context); + } + + Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, + std::unique_ptr* listing) override { + pb::Criteria pb_criteria; + RETURN_NOT_OK(internal::ToProto(criteria, &pb_criteria)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + std::unique_ptr<::grpc::ClientReader> stream( + stub_->ListFlights(&rpc.context, pb_criteria)); + + std::vector flights; + + pb::FlightInfo pb_info; + while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) { + FlightInfo::Data info_data; + RETURN_NOT_OK(internal::FromProto(pb_info, &info_data)); + flights.emplace_back(std::move(info_data)); + } + if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); + RETURN_NOT_OK(options.stop_token.Poll()); + listing->reset(new SimpleFlightListing(std::move(flights))); + return FromGrpcStatus(stream->Finish(), &rpc.context); + } + + Status DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr* results) override { + pb::Action pb_action; + RETURN_NOT_OK(internal::ToProto(action, &pb_action)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + std::unique_ptr<::grpc::ClientReader> stream( + stub_->DoAction(&rpc.context, pb_action)); + + pb::Result pb_result; + + std::vector materialized_results; + while (!options.stop_token.IsStopRequested() && stream->Read(&pb_result)) { + Result result; + 
RETURN_NOT_OK(internal::FromProto(pb_result, &result)); + materialized_results.emplace_back(std::move(result)); + } + if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); + RETURN_NOT_OK(options.stop_token.Poll()); + + *results = std::unique_ptr( + new SimpleResultStream(std::move(materialized_results))); + return FromGrpcStatus(stream->Finish(), &rpc.context); + } + + Status ListActions(const FlightCallOptions& options, + std::vector* types) override { + pb::Empty empty; + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + std::unique_ptr<::grpc::ClientReader> stream( + stub_->ListActions(&rpc.context, empty)); + + pb::ActionType pb_type; + ActionType type; + while (!options.stop_token.IsStopRequested() && stream->Read(&pb_type)) { + RETURN_NOT_OK(internal::FromProto(pb_type, &type)); + types->emplace_back(std::move(type)); + } + if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); + RETURN_NOT_OK(options.stop_token.Poll()); + return FromGrpcStatus(stream->Finish(), &rpc.context); + } + + Status GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr* info) override { + pb::FlightDescriptor pb_descriptor; + pb::FlightInfo pb_response; + + RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + Status s = FromGrpcStatus( + stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context); + RETURN_NOT_OK(s); + + FlightInfo::Data info_data; + RETURN_NOT_OK(internal::FromProto(pb_response, &info_data)); + info->reset(new FlightInfo(std::move(info_data))); + return Status::OK(); + } + + Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, + std::unique_ptr* schema_result) override { + pb::FlightDescriptor pb_descriptor; + pb::SchemaResult pb_response; + + RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); + + ClientRpc 
rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + Status s = FromGrpcStatus(stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), + &rpc.context); + RETURN_NOT_OK(s); + + std::string str; + RETURN_NOT_OK(internal::FromProto(pb_response, &str)); + schema_result->reset(new SchemaResult(str)); + return Status::OK(); + } + + Status DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr* out) override { + pb::Ticket pb_ticket; + RETURN_NOT_OK(internal::ToProto(ticket, &pb_ticket)); + + auto rpc = std::make_shared(options); + RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); + std::shared_ptr<::grpc::ClientReader> stream = + stub_->DoGet(&rpc->context, pb_ticket); + *out = std::unique_ptr(new GrpcClientGetStream( + std::move(rpc), std::move(stream), options.memory_manager)); + return Status::OK(); + } + + Status DoPut(const FlightCallOptions& options, + std::unique_ptr* out) override { + using GrpcStream = ::grpc::ClientReaderWriter; + + auto rpc = std::make_shared(options); + RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); + std::shared_ptr stream = stub_->DoPut(&rpc->context); + *out = std::unique_ptr(new GrpcClientPutStream( + std::move(rpc), std::move(stream), options.memory_manager)); + return Status::OK(); + } + + Status DoExchange(const FlightCallOptions& options, + std::unique_ptr* out) override { + using GrpcStream = ::grpc::ClientReaderWriter; + + auto rpc = std::make_shared(options); + RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); + std::shared_ptr stream = stub_->DoExchange(&rpc->context); + *out = std::unique_ptr(new GrpcClientExchangeStream( + std::move(rpc), std::move(stream), options.memory_manager)); + return Status::OK(); + } + + private: + std::unique_ptr stub_; + std::shared_ptr auth_handler_; +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) && \ + !defined(GRPC_USE_CERTIFICATE_VERIFIER) + // Scope the TlsServerAuthorizationCheckConfig to be at the class instance level, since + // it 
gets created during Connect() and needs to persist to DoAction() calls. gRPC does + // not correctly increase the reference count of this object: + // https://github.com/grpc/grpc/issues/22287 + std::shared_ptr< + ::GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> + noop_auth_check_; +#endif +}; +std::once_flag kGrpcClientTransportInitialized; +} // namespace + +void InitializeFlightGrpcClient() { + std::call_once(kGrpcClientTransportInitialized, []() { + auto* registry = flight::internal::GetDefaultTransportRegistry(); + for (const auto& transport : {"grpc", "grpc+tls", "grpc+tcp", "grpc+unix"}) { + ARROW_CHECK_OK(registry->RegisterClient(transport, GrpcClientImpl::Make)); + } + }); +} + +} // namespace grpc +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/middleware_internal.h b/cpp/src/arrow/flight/transport/grpc/grpc_client.h similarity index 68% rename from cpp/src/arrow/flight/middleware_internal.h rename to cpp/src/arrow/flight/transport/grpc/grpc_client.h index 8ee76476a46..6a75c9d57ab 100644 --- a/cpp/src/arrow/flight/middleware_internal.h +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.h @@ -15,32 +15,22 @@ // specific language governing permissions and limitations // under the License. -// Interfaces for defining middleware for Flight clients and -// servers. Currently experimental. +// gRPC-based transport for Flight. #pragma once -#include "arrow/flight/platform.h" -#include "arrow/flight/visibility.h" // IWYU pragma: keep - -#include -#include -#include - -#ifdef GRPCPP_PP_INCLUDE -#include -#else -#include -#endif - -#include "arrow/flight/middleware.h" +#include "arrow/flight/visibility.h" namespace arrow { - namespace flight { +namespace transport { +namespace grpc { -namespace internal {} // namespace internal +/// \brief Register the gRPC transport implementation. Idempotent. 
+ARROW_FLIGHT_EXPORT +void InitializeFlightGrpcClient(); +} // namespace grpc +} // namespace transport } // namespace flight - } // namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc new file mode 100644 index 00000000000..5a2901c1d54 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -0,0 +1,635 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// gRPC transport implementation for Arrow Flight + +#include "arrow/flight/transport/grpc/grpc_server.h" + +#include +#include +#include +#include +#include + +#include "arrow/util/config.h" +#ifdef GRPCPP_PP_INCLUDE +#include +#else +#include +#endif + +#include "arrow/buffer.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/grpc/serialization_internal.h" +#include "arrow/flight/transport/grpc/util_internal.h" +#include "arrow/flight/transport_server.h" +#include "arrow/flight/types.h" +#include "arrow/util/logging.h" +#include "arrow/util/uri.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +namespace pb = arrow::flight::protocol; +using FlightService = pb::FlightService; +using ServerContext = ::grpc::ServerContext; +template +using ServerWriter = ::grpc::ServerWriter; + +// Macro that runs interceptors before returning the given status +#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS) \ + do { \ + const auto& __s = (STATUS); \ + return CONTEXT.FinishRequest(__s); \ + } while (false) +#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE) \ + if (VAL == nullptr) { \ + RETURN_WITH_MIDDLEWARE( \ + CONTEXT, ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, MESSAGE)); \ + } +// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC status, and +// will run interceptors +#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \ + do { \ + const auto& _s = (expr); \ + if (ARROW_PREDICT_FALSE(!_s.ok())) { \ + return CONTEXT.FinishRequest(_s); \ + } \ + } while (false) + +namespace { +class GrpcServerAuthReader : public ServerAuthReader { + public: + explicit GrpcServerAuthReader( + ::grpc::ServerReaderWriter* stream) + : stream_(stream) {} + + Status Read(std::string* token) override { + pb::HandshakeRequest request; + if (stream_->Read(&request)) { + *token = 
std::move(*request.mutable_payload()); + return Status::OK(); + } + return Status::IOError("Stream is closed."); + } + + private: + ::grpc::ServerReaderWriter* stream_; +}; + +class GrpcServerAuthSender : public ServerAuthSender { + public: + explicit GrpcServerAuthSender( + ::grpc::ServerReaderWriter* stream) + : stream_(stream) {} + + Status Write(const std::string& token) override { + pb::HandshakeResponse response; + response.set_payload(token); + if (stream_->Write(response)) { + return Status::OK(); + } + return Status::IOError("Stream was closed."); + } + + private: + ::grpc::ServerReaderWriter* stream_; +}; + +class GrpcServerCallContext : public ServerCallContext { + explicit GrpcServerCallContext(::grpc::ServerContext* context) + : context_(context), peer_(context_->peer()) {} + + const std::string& peer_identity() const override { return peer_identity_; } + const std::string& peer() const override { return peer_; } + bool is_cancelled() const override { return context_->IsCancelled(); } + + // Helper method that runs interceptors given the result of an RPC, + // then returns the final gRPC status to send to the client + ::grpc::Status FinishRequest(const ::grpc::Status& status) { + // Don't double-convert status - return the original one here + FinishRequest(FromGrpcStatus(status)); + return status; + } + + ::grpc::Status FinishRequest(const arrow::Status& status) { + for (const auto& instance : middleware_) { + instance->CallCompleted(status); + } + + // Set custom headers to map the exact Arrow status for clients + // who want it. 
+ return ToGrpcStatus(status, context_); + } + + ServerMiddleware* GetMiddleware(const std::string& key) const override { + const auto& instance = middleware_map_.find(key); + if (instance == middleware_map_.end()) { + return nullptr; + } + return instance->second.get(); + } + + private: + friend class GrpcServiceHandler; + ServerContext* context_; + std::string peer_; + std::string peer_identity_; + std::vector> middleware_; + std::unordered_map> middleware_map_; +}; + +class GrpcAddServerHeaders : public AddCallHeaders { + public: + explicit GrpcAddServerHeaders(::grpc::ServerContext* context) : context_(context) {} + ~GrpcAddServerHeaders() override = default; + + void AddHeader(const std::string& key, const std::string& value) override { + context_->AddInitialMetadata(key, value); + } + + private: + ::grpc::ServerContext* context_; +}; + +// A ServerDataStream for streaming data to the client. +class GetDataStream : public internal::ServerDataStream { + public: + explicit GetDataStream(ServerWriter* writer) : writer_(writer) {} + + arrow::Result WriteData(const FlightPayload& payload) override { + return WritePayload(payload, writer_); + } + + private: + ServerWriter* writer_; +}; + +// A ServerDataStream for reading data from the client. +class PutDataStream final : public internal::ServerDataStream { + public: + explicit PutDataStream( + ::grpc::ServerReaderWriter* stream) + : stream_(stream) {} + + bool ReadData(internal::FlightData* data) override { + return ReadPayload(&*stream_, data); + } + Status WritePutMetadata(const Buffer& metadata) override { + pb::PutResult message{}; + message.set_app_metadata(metadata.data(), metadata.size()); + if (stream_->Write(message)) { + return Status::OK(); + } + return Status::IOError("Unknown error writing metadata."); + } + + private: + ::grpc::ServerReaderWriter* stream_; +}; + +// A ServerDataStream for a bidirectional data exchange. 
+class ExchangeDataStream final : public internal::ServerDataStream { + public: + explicit ExchangeDataStream( + ::grpc::ServerReaderWriter* stream) + : stream_(stream) {} + + bool ReadData(internal::FlightData* data) override { + return ReadPayload(&*stream_, data); + } + arrow::Result WriteData(const FlightPayload& payload) override { + return WritePayload(payload, stream_); + } + + private: + ::grpc::ServerReaderWriter* stream_; +}; + +// The gRPC service implementation, which forwards calls to the Flight +// service and bridges between the Flight transport API and gRPC. +class GrpcServiceHandler final : public FlightService::Service { + public: + GrpcServiceHandler( + std::shared_ptr auth_handler, + std::vector>> + middleware, + internal::ServerTransport* impl) + : auth_handler_(auth_handler), middleware_(middleware), impl_(impl) {} + + template + ::grpc::Status WriteStream(Iterator* iterator, ServerWriter* writer) { + if (!iterator) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, "No items to iterate"); + } + // Write flight info to stream until listing is exhausted + while (true) { + ProtoType pb_value; + std::unique_ptr value; + GRPC_RETURN_NOT_OK(iterator->Next(&value)); + if (!value) { + break; + } + GRPC_RETURN_NOT_OK(internal::ToProto(*value, &pb_value)); + + // Blocking write + if (!writer->Write(pb_value)) { + // Write returns false if the stream is closed + break; + } + } + return ::grpc::Status::OK; + } + + template + ::grpc::Status WriteStream(const std::vector& values, + ServerWriter* writer) { + // Write flight info to stream until listing is exhausted + for (const UserType& value : values) { + ProtoType pb_value; + GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value)); + // Blocking write + if (!writer->Write(pb_value)) { + // Write returns false if the stream is closed + break; + } + } + return ::grpc::Status::OK; + } + + // Authenticate the client (if applicable) and construct the call context + ::grpc::Status CheckAuth(const 
FlightMethod& method, ServerContext* context, + GrpcServerCallContext& flight_context) { + if (!auth_handler_) { + const auto auth_context = context->auth_context(); + if (auth_context && auth_context->IsPeerAuthenticated()) { + auto peer_identity = auth_context->GetPeerIdentity(); + flight_context.peer_identity_ = + peer_identity.empty() + ? "" + : std::string(peer_identity.front().begin(), peer_identity.front().end()); + } else { + flight_context.peer_identity_ = ""; + } + } else { + const auto client_metadata = context->client_metadata(); + const auto auth_header = client_metadata.find(kGrpcAuthHeader); + std::string token; + if (auth_header == client_metadata.end()) { + token = ""; + } else { + token = std::string(auth_header->second.data(), auth_header->second.length()); + } + GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_)); + } + + return MakeCallContext(method, context, flight_context); + } + + // Start the server middleware for this call and record the instances in the call context (no authentication is done here) + ::grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, + GrpcServerCallContext& flight_context) { + // Run server middleware + const CallInfo info{method}; + CallHeaders incoming_headers; + for (const auto& entry : context->client_metadata()) { + incoming_headers.insert( + {util::string_view(entry.first.data(), entry.first.length()), + util::string_view(entry.second.data(), entry.second.length())}); + } + + GrpcAddServerHeaders outgoing_headers(context); + for (const auto& factory : middleware_) { + std::shared_ptr instance; + Status result = factory.second->StartCall(info, incoming_headers, &instance); + if (!result.ok()) { + // Interceptor rejected call, end the request on all existing + // interceptors + return flight_context.FinishRequest(result); + } + if (instance != nullptr) { + flight_context.middleware_.push_back(instance); + flight_context.middleware_map_.insert({factory.first, instance}); +
instance->SendingHeaders(&outgoing_headers); + } + } + + return ::grpc::Status::OK; + } + + ::grpc::Status Handshake( + ServerContext* context, + ::grpc::ServerReaderWriter* stream) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK( + MakeCallContext(FlightMethod::Handshake, context, flight_context)); + + if (!auth_handler_) { + RETURN_WITH_MIDDLEWARE( + flight_context, + ::grpc::Status( + ::grpc::StatusCode::UNIMPLEMENTED, + "This service does not have an authentication mechanism enabled.")); + } + GrpcServerAuthSender outgoing{stream}; + GrpcServerAuthReader incoming{stream}; + RETURN_WITH_MIDDLEWARE(flight_context, + auth_handler_->Authenticate(&outgoing, &incoming)); + } + + ::grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request, + ServerWriter* writer) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK( + CheckAuth(FlightMethod::ListFlights, context, flight_context)); + + // Retrieve the listing from the implementation + std::unique_ptr listing; + + Criteria criteria; + if (request) { + SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &criteria)); + } + SERVICE_RETURN_NOT_OK( + flight_context, impl_->base()->ListFlights(flight_context, &criteria, &listing)); + if (!listing) { + // Treat null listing as no flights available + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::OK); + } + RETURN_WITH_MIDDLEWARE(flight_context, + WriteStream(listing.get(), writer)); + } + + ::grpc::Status GetFlightInfo(ServerContext* context, + const pb::FlightDescriptor* request, + pb::FlightInfo* response) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK( + CheckAuth(FlightMethod::GetFlightInfo, context, flight_context)); + + CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); + + FlightDescriptor descr; + SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); + + std::unique_ptr info; + 
SERVICE_RETURN_NOT_OK(flight_context, + impl_->base()->GetFlightInfo(flight_context, descr, &info)); + + if (!info) { + // A null result means the flight was not found + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status(::grpc::StatusCode::NOT_FOUND, + "Flight not found")); + } + + SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*info, response)); + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::OK); + } + + ::grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, + pb::SchemaResult* response) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context)); + + CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); + + FlightDescriptor descr; + SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); + + std::unique_ptr result; + SERVICE_RETURN_NOT_OK(flight_context, + impl_->base()->GetSchema(flight_context, descr, &result)); + + if (!result) { + // A null result means the flight was not found + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status(::grpc::StatusCode::NOT_FOUND, + "Flight not found")); + } + + SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, response)); + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::OK); + } + + ::grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, + ServerWriter* writer) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); + + CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null"); + + Ticket ticket; + SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &ticket)); + + GetDataStream stream(writer); + RETURN_WITH_MIDDLEWARE(flight_context, + impl_->DoGet(flight_context, std::move(ticket), &stream)); + } + + ::grpc::Status DoPut( + ServerContext* context, + ::grpc::ServerReaderWriter* reader) { + GrpcServerCallContext
flight_context(context); + GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); + + PutDataStream stream(reader); + RETURN_WITH_MIDDLEWARE(flight_context, impl_->DoPut(flight_context, &stream)); + } + + ::grpc::Status DoExchange( + ServerContext* context, + ::grpc::ServerReaderWriter* stream) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context)); + + ExchangeDataStream data_stream(stream); + RETURN_WITH_MIDDLEWARE(flight_context, + impl_->DoExchange(flight_context, &data_stream)); + } + + ::grpc::Status ListActions(ServerContext* context, const pb::Empty* request, + ServerWriter* writer) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK( + CheckAuth(FlightMethod::ListActions, context, flight_context)); + // Retrieve the listing from the implementation + std::vector types; + SERVICE_RETURN_NOT_OK(flight_context, + impl_->base()->ListActions(flight_context, &types)); + RETURN_WITH_MIDDLEWARE(flight_context, WriteStream(types, writer)); + } + + ::grpc::Status DoAction(ServerContext* context, const pb::Action* request, + ServerWriter* writer) { + GrpcServerCallContext flight_context(context); + GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context)); + CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null"); + Action action; + SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &action)); + + std::unique_ptr results; + SERVICE_RETURN_NOT_OK(flight_context, + impl_->base()->DoAction(flight_context, action, &results)); + + if (!results) { + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::CANCELLED); + } + + while (true) { + std::unique_ptr result; + SERVICE_RETURN_NOT_OK(flight_context, results->Next(&result)); + if (!result) { + // No more results + break; + } + pb::Result pb_result; + SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, &pb_result)); + if 
(!writer->Write(pb_result)) { + // Stream may be closed + break; + } + } + RETURN_WITH_MIDDLEWARE(flight_context, ::grpc::Status::OK); + } + + private: + std::shared_ptr auth_handler_; + std::vector>> + middleware_; + internal::ServerTransport* impl_; +}; + +// The ServerTransport implementation for gRPC. Manages the gRPC server itself. +class GrpcServerTransport : public internal::ServerTransport { + public: + using internal::ServerTransport::ServerTransport; + + static arrow::Result> Make( + FlightServerBase* base, std::shared_ptr memory_manager) { + return std::unique_ptr( + new GrpcServerTransport(base, std::move(memory_manager))); + } + + Status Init(const FlightServerOptions& options, + const arrow::internal::Uri& uri) override { + grpc_service_.reset( + new GrpcServiceHandler(options.auth_handler, options.middleware, this)); + + ::grpc::ServerBuilder builder; + // Allow uploading messages of any length + builder.SetMaxReceiveMessageSize(-1); + + const std::string scheme = uri.scheme(); + int port = 0; + if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { + std::stringstream address; + address << arrow::internal::UriEncodeHost(uri.host()) << ':' << uri.port_text(); + + std::shared_ptr<::grpc::ServerCredentials> creds; + if (scheme == kSchemeGrpcTls) { + ::grpc::SslServerCredentialsOptions ssl_options; + for (const auto& pair : options.tls_certificates) { + ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert}); + } + if (options.verify_client) { + ssl_options.client_certificate_request = + GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; + } + if (!options.root_certificates.empty()) { + ssl_options.pem_root_certs = options.root_certificates; + } + creds = ::grpc::SslServerCredentials(ssl_options); + } else { + creds = ::grpc::InsecureServerCredentials(); + } + + builder.AddListeningPort(address.str(), creds, &port); + } else if (scheme == kSchemeGrpcUnix) { + std::stringstream address; + address << 
"unix:" << uri.path(); + builder.AddListeningPort(address.str(), ::grpc::InsecureServerCredentials()); + location_ = options.location; + } else { + return Status::NotImplemented("Scheme is not supported: " + scheme); + } + + builder.RegisterService(grpc_service_.get()); + + // Disable SO_REUSEPORT - it makes debugging/testing a pain as + // leftover processes can handle requests on accident + builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); + + if (options.builder_hook) { + options.builder_hook(&builder); + } + + grpc_server_ = builder.BuildAndStart(); + if (!grpc_server_) { + return Status::UnknownError("Server did not start properly"); + } + + if (scheme == kSchemeGrpcTls) { + RETURN_NOT_OK(Location::ForGrpcTls(uri.host(), port, &location_)); + } else if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp) { + RETURN_NOT_OK(Location::ForGrpcTcp(uri.host(), port, &location_)); + } + return Status::OK(); + } + Status Shutdown() override { + grpc_server_->Shutdown(); + return Status::OK(); + } + Status Shutdown(const std::chrono::system_clock::time_point& deadline) override { + grpc_server_->Shutdown(deadline); + return Status::OK(); + } + Status Wait() override { + grpc_server_->Wait(); + return Status::OK(); + } + Location location() const override { return location_; } + + private: + std::unique_ptr grpc_service_; + std::unique_ptr<::grpc::Server> grpc_server_; + Location location_; +}; + +std::once_flag kGrpcServerTransportInitialized; +} // namespace + +void InitializeFlightGrpcServer() { + std::call_once(kGrpcServerTransportInitialized, []() { + auto* registry = flight::internal::GetDefaultTransportRegistry(); + for (const auto& transport : {"grpc", "grpc+tls", "grpc+tcp", "grpc+unix"}) { + ARROW_CHECK_OK(registry->RegisterServer(transport, GrpcServerTransport::Make)); + } + }); +} + +#undef CHECK_ARG_NOT_NULL +#undef RETURN_WITH_MIDDLEWARE +#undef SERVICE_RETURN_NOT_OK + +} // namespace grpc +} // namespace transport +} // namespace flight +} // 
namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.h b/cpp/src/arrow/flight/transport/grpc/grpc_server.h new file mode 100644 index 00000000000..025ee70b204 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// gRPC-based transport for Flight. + +#pragma once + +#include "arrow/flight/visibility.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +/// \brief Register the gRPC transport implementation. Idempotent. +ARROW_FLIGHT_EXPORT +void InitializeFlightGrpcServer(); + +} // namespace grpc +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/protocol_internal.cc b/cpp/src/arrow/flight/transport/grpc/protocol_grpc_internal.cc similarity index 89% rename from cpp/src/arrow/flight/protocol_internal.cc rename to cpp/src/arrow/flight/transport/grpc/protocol_grpc_internal.cc index 9f815398e48..63cfccdbb5f 100644 --- a/cpp/src/arrow/flight/protocol_internal.cc +++ b/cpp/src/arrow/flight/transport/grpc/protocol_grpc_internal.cc @@ -14,13 +14,12 @@ // KIND, either express or implied. 
See the License for the // specific language governing permissions and limitations -#include "arrow/flight/protocol_internal.h" +#include "arrow/flight/transport/grpc/protocol_grpc_internal.h" // NOTE(wesm): Including .cc files in another .cc file would ordinarily be a // no-no. We have customized the serialization path for FlightData, which is // currently only possible through some pre-processor commands that need to be // included before either of these files is compiled. Because we don't want to // edit the generated C++ files, we include them here and do our gRPC -// customizations in protocol-internal.h +// customizations in protocol_grpc_internal.h #include "arrow/flight/Flight.grpc.pb.cc" // NOLINT -#include "arrow/flight/Flight.pb.cc" // NOLINT diff --git a/cpp/src/arrow/flight/transport/grpc/protocol_grpc_internal.h b/cpp/src/arrow/flight/transport/grpc/protocol_grpc_internal.h new file mode 100644 index 00000000000..4445cf789bb --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/protocol_grpc_internal.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations + +#pragma once + +// This addresses platform-specific defines, e.g. 
on Windows +#include "arrow/flight/platform.h" // IWYU pragma: keep + +// This header holds the Flight gRPC definitions. + +// Need to include this first to get our gRPC customizations +#include "arrow/flight/transport/grpc/customize_grpc.h" // IWYU pragma: export + +#include "arrow/flight/Flight.grpc.pb.h" // IWYU pragma: export diff --git a/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc b/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc new file mode 100644 index 00000000000..e51da615bbd --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/serialization_internal.cc @@ -0,0 +1,481 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/transport/grpc/serialization_internal.h" + +// todo cleanup includes + +#include +#include +#include +#include + +#include "arrow/flight/platform.h" + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4267) +#endif + +#include +#include +#include + +#include +#ifdef GRPCPP_PP_INCLUDE +#include +#include +#else +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#include "arrow/buffer.h" +#include "arrow/device.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/transport.h" +#include "arrow/flight/transport/grpc/util_internal.h" +#include "arrow/ipc/message.h" +#include "arrow/ipc/writer.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +namespace pb = arrow::flight::protocol; + +static constexpr int64_t kInt32Max = std::numeric_limits::max(); +using google::protobuf::internal::WireFormatLite; +using google::protobuf::io::ArrayOutputStream; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::CodedOutputStream; + +using ::grpc::ByteBuffer; + +bool ReadBytesZeroCopy(const std::shared_ptr& source_data, + CodedInputStream* input, std::shared_ptr* out) { + uint32_t length; + if (!input->ReadVarint32(&length)) { + return false; + } + auto buf = + SliceBuffer(source_data, input->CurrentPosition(), static_cast(length)); + *out = buf; + return input->Skip(static_cast(length)); +} + +// Internal wrapper for gRPC ByteBuffer so its memory can be exposed to Arrow +// consumers with zero-copy +class GrpcBuffer : public MutableBuffer { + public: + GrpcBuffer(grpc_slice slice, bool incref) + : MutableBuffer(GRPC_SLICE_START_PTR(slice), + static_cast(GRPC_SLICE_LENGTH(slice))), + slice_(incref ? 
grpc_slice_ref(slice) : slice) {} + + ~GrpcBuffer() override { + // Decref slice + grpc_slice_unref(slice_); + } + + static Status Wrap(ByteBuffer* cpp_buf, std::shared_ptr* out) { + // These types are guaranteed by static assertions in gRPC to have the same + // in-memory representation + + auto buffer = *reinterpret_cast(cpp_buf); + + // This part below is based on the Flatbuffers gRPC SerializationTraits in + // flatbuffers/grpc.h + + // Check if this is a single uncompressed slice. + if ((buffer->type == GRPC_BB_RAW) && + (buffer->data.raw.compression == GRPC_COMPRESS_NONE) && + (buffer->data.raw.slice_buffer.count == 1)) { + // If it is, then we can reference the `grpc_slice` directly. + grpc_slice slice = buffer->data.raw.slice_buffer.slices[0]; + + if (slice.refcount) { + // Increment reference count so this memory remains valid + *out = std::make_shared(slice, true); + } else { + // Small slices (less than GRPC_SLICE_INLINED_SIZE bytes) are + // inlined into the structure and must be copied. + const uint8_t length = slice.data.inlined.length; + ARROW_ASSIGN_OR_RAISE(*out, arrow::AllocateBuffer(length)); + std::memcpy((*out)->mutable_data(), slice.data.inlined.bytes, length); + } + } else { + // Otherwise, we need to use `grpc_byte_buffer_reader_readall` to read + // `buffer` into a single contiguous `grpc_slice`. The gRPC reader gives + // us back a new slice with the refcount already incremented. 
+ grpc_byte_buffer_reader reader; + if (!grpc_byte_buffer_reader_init(&reader, buffer)) { + return Status::IOError("Internal gRPC error reading from ByteBuffer"); + } + grpc_slice slice = grpc_byte_buffer_reader_readall(&reader); + grpc_byte_buffer_reader_destroy(&reader); + + // Steal the slice reference + *out = std::make_shared(slice, false); + } + + return Status::OK(); + } + + private: + grpc_slice slice_; +}; + +// Destructor callback for grpc::Slice +static void ReleaseBuffer(void* buf_ptr) { + delete reinterpret_cast*>(buf_ptr); +} + +// Initialize gRPC Slice from arrow Buffer +arrow::Result<::grpc::Slice> SliceFromBuffer(const std::shared_ptr& buf) { + // Allocate persistent shared_ptr to control Buffer lifetime + std::shared_ptr* ptr = nullptr; + if (ARROW_PREDICT_TRUE(buf->is_cpu())) { + ptr = new std::shared_ptr(buf); + } else { + // Non-CPU buffer, must copy to CPU-accessible buffer first + ARROW_ASSIGN_OR_RAISE(auto cpu_buf, + Buffer::ViewOrCopy(buf, default_cpu_memory_manager())); + ptr = new std::shared_ptr(cpu_buf); + } + ::grpc::Slice slice(const_cast((*ptr)->data()), + static_cast((*ptr)->size()), &ReleaseBuffer, ptr); + // Make sure no copy was done (some grpc::Slice() constructors do an implicit memcpy) + DCHECK_EQ(slice.begin(), (*ptr)->data()); + return slice; +} + +static const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + +// Update the sizes of our Protobuf fields based on the given IPC payload. 
+::grpc::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool has_body, + size_t* header_size, int32_t* metadata_size) { + DCHECK_LE(ipc_msg.metadata->size(), kInt32Max); + *metadata_size = static_cast(ipc_msg.metadata->size()); + + // 1 byte for metadata tag + *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size); + + // 2 bytes for body tag + if (has_body) { + // We write the body tag in the header but not the actual body data + *header_size += 2 + WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) - + ipc_msg.body_length; + } + + return ::grpc::Status::OK; +} + +::grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, + bool* own_buffer) { + // Size of the IPC body (protobuf: data_body) + size_t body_size = 0; + // Size of the Protobuf "header" (everything except for the body) + size_t header_size = 0; + // Size of IPC header metadata (protobuf: data_header) + int32_t metadata_size = 0; + + // Write the descriptor if present + int32_t descriptor_size = 0; + if (msg.descriptor != nullptr) { + DCHECK_LE(msg.descriptor->size(), kInt32Max); + descriptor_size = static_cast(msg.descriptor->size()); + header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size); + } + + // App metadata tag if appropriate + int32_t app_metadata_size = 0; + if (msg.app_metadata && msg.app_metadata->size() > 0) { + DCHECK_LE(msg.app_metadata->size(), kInt32Max); + app_metadata_size = static_cast(msg.app_metadata->size()); + header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size); + } + + const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message; + // No data in this payload (metadata-only). + bool has_ipc = ipc_msg.type != ipc::MessageType::NONE; + bool has_body = has_ipc ? 
ipc::Message::HasBody(ipc_msg.type) : false; + + if (has_ipc) { + DCHECK(has_body || ipc_msg.body_length == 0); + GRPC_RETURN_NOT_GRPC_OK( + IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size)); + body_size = static_cast(ipc_msg.body_length); + } + + // TODO(wesm): messages over 2GB unlikely to be yet supported + // Validated in WritePayload since returning error here causes gRPC to fail an assertion + DCHECK_LE(body_size, kInt32Max); + + // Allocate and initialize slices + std::vector<::grpc::Slice> slices; + slices.emplace_back(header_size); + + // Force the header_stream to be destructed, which actually flushes + // the data into the slice. + { + ArrayOutputStream header_writer(const_cast(slices[0].begin()), + static_cast(slices[0].size())); + CodedOutputStream header_stream(&header_writer); + + // Write descriptor + if (msg.descriptor != nullptr) { + WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(descriptor_size); + header_stream.WriteRawMaybeAliased(msg.descriptor->data(), + static_cast(msg.descriptor->size())); + } + + // Write header + if (has_ipc) { + WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(metadata_size); + header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), + static_cast(ipc_msg.metadata->size())); + } + + // Write app metadata + if (app_metadata_size > 0) { + WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(app_metadata_size); + header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), + static_cast(msg.app_metadata->size())); + } + + if (has_body) { + // Write body tag + WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&header_stream); + header_stream.WriteVarint32(static_cast(body_size)); + + // Enqueue body buffers for writing, without copying + for (const auto& buffer : ipc_msg.body_buffers) { + // Buffer may be null when the row length is zero, or when all + // entries are invalid. + if (!buffer) continue; + + ::grpc::Slice slice; + auto status = SliceFromBuffer(buffer).Value(&slice); + if (ARROW_PREDICT_FALSE(!status.ok())) { + // This will likely lead to abort as gRPC cannot recover from an error here + return ToGrpcStatus(status); + } + slices.push_back(std::move(slice)); + + // Write padding if not multiple of 8 + const auto remainder = static_cast( + bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size()); + if (remainder) { + slices.push_back(::grpc::Slice(kPaddingBytes, remainder)); + } + } + } + + DCHECK_EQ(static_cast(header_size), header_stream.ByteCount()); + } + + // Hand off the slices to the returned ByteBuffer + *out = ::grpc::ByteBuffer(slices.data(), slices.size()); + *own_buffer = true; + return ::grpc::Status::OK; +} + +// Read internal::FlightData from grpc::ByteBuffer containing FlightData +// protobuf without copying +::grpc::Status FlightDataDeserialize(ByteBuffer* buffer, + arrow::flight::internal::FlightData* out) { + if (!buffer) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, "No payload"); + } + + // Reset fields in case the caller reuses a single allocation + out->descriptor = nullptr; + out->app_metadata = nullptr; + out->metadata = nullptr; + out->body = nullptr; + + std::shared_ptr wrapped_buffer; + GRPC_RETURN_NOT_OK(GrpcBuffer::Wrap(buffer, &wrapped_buffer)); + + auto buffer_length = static_cast(wrapped_buffer->size()); + CodedInputStream pb_stream(wrapped_buffer->data(), buffer_length); + + pb_stream.SetTotalBytesLimit(buffer_length); + + // This is the bytes remaining when using CodedInputStream like this + while (pb_stream.BytesUntilTotalBytesLimit()) { + const uint32_t tag = pb_stream.ReadTag(); + const int field_number = 
WireFormatLite::GetTagFieldNumber(tag); + switch (field_number) { + case pb::FlightData::kFlightDescriptorFieldNumber: { + pb::FlightDescriptor pb_descriptor; + uint32_t length; + if (!pb_stream.ReadVarint32(&length)) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, + "Unable to parse length of FlightDescriptor"); + } + // Can't use ParseFromCodedStream as this reads the entire + // rest of the stream into the descriptor command field. + std::string buffer; + pb_stream.ReadString(&buffer, length); + if (!pb_descriptor.ParseFromString(buffer)) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, + "Unable to parse FlightDescriptor"); + } + arrow::flight::FlightDescriptor descriptor; + GRPC_RETURN_NOT_OK( + arrow::flight::internal::FromProto(pb_descriptor, &descriptor)); + out->descriptor.reset(new arrow::flight::FlightDescriptor(descriptor)); + } break; + case pb::FlightData::kDataHeaderFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->metadata)) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, + "Unable to read FlightData metadata"); + } + } break; + case pb::FlightData::kAppMetadataFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, + "Unable to read FlightData application metadata"); + } + } break; + case pb::FlightData::kDataBodyFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { + return ::grpc::Status(::grpc::StatusCode::INTERNAL, + "Unable to read FlightData body"); + } + } break; + default: + DCHECK(false) << "cannot happen"; + } + } + buffer->Clear(); + + // TODO(wesm): Where and when should we verify that the FlightData is not + // malformed? + + // Set the default value for an unspecified FlightData body. The other + // fields can be null if they're unspecified. 
+ if (out->body == nullptr) { + out->body = std::make_shared(nullptr, 0); + } + + return ::grpc::Status::OK; +} + +// The pointer bitcast hack below causes legitimate warnings, silence them. +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +// Pointer bitcast explanation: grpc::*Writer::Write() and grpc::*Reader::Read() +// both take a T* argument (here pb::FlightData*). But they don't do anything +// with that argument except pass it to SerializationTraits::Serialize() and +// SerializationTraits::Deserialize(). +// +// Since we control SerializationTraits, we can interpret the +// pointer argument whichever way we want, including cast it back to the original type. +// (see customize_grpc.h). + +arrow::Result WritePayload( + const FlightPayload& payload, + ::grpc::ClientReaderWriter* writer) { + RETURN_NOT_OK(payload.Validate()); + // Pretend to be pb::FlightData and intercept in SerializationTraits + return writer->Write(*reinterpret_cast(&payload), + ::grpc::WriteOptions()); +} + +arrow::Result WritePayload( + const FlightPayload& payload, + ::grpc::ClientReaderWriter* writer) { + RETURN_NOT_OK(payload.Validate()); + // Pretend to be pb::FlightData and intercept in SerializationTraits + return writer->Write(*reinterpret_cast(&payload), + ::grpc::WriteOptions()); +} + +arrow::Result WritePayload( + const FlightPayload& payload, + ::grpc::ServerReaderWriter* writer) { + RETURN_NOT_OK(payload.Validate()); + // Pretend to be pb::FlightData and intercept in SerializationTraits + return writer->Write(*reinterpret_cast(&payload), + ::grpc::WriteOptions()); +} + +arrow::Result WritePayload(const FlightPayload& payload, + ::grpc::ServerWriter* writer) { + RETURN_NOT_OK(payload.Validate()); + // Pretend to be pb::FlightData and intercept in SerializationTraits + return writer->Write(*reinterpret_cast(&payload), + ::grpc::WriteOptions()); +} + +bool ReadPayload(::grpc::ClientReader* reader, + 
flight::internal::FlightData* data) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return reader->Read(reinterpret_cast(data)); +} + +bool ReadPayload(::grpc::ClientReaderWriter* reader, + flight::internal::FlightData* data) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return reader->Read(reinterpret_cast(data)); +} + +bool ReadPayload(::grpc::ServerReaderWriter* reader, + flight::internal::FlightData* data) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return reader->Read(reinterpret_cast(data)); +} + +bool ReadPayload(::grpc::ServerReaderWriter* reader, + flight::internal::FlightData* data) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return reader->Read(reinterpret_cast(data)); +} + +bool ReadPayload(::grpc::ClientReaderWriter* reader, + pb::PutResult* data) { + return reader->Read(data); +} + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif + +} // namespace grpc +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/serialization_internal.h b/cpp/src/arrow/flight/transport/grpc/serialization_internal.h new file mode 100644 index 00000000000..5c347fd4f81 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/serialization_internal.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (De)serialization utilities that hook into gRPC, efficiently +// handling Arrow-encoded data in a gRPC call. + +#pragma once + +#include + +#include "arrow/flight/protocol_internal.h" +#include "arrow/flight/transport/grpc/protocol_grpc_internal.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/result.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +namespace pb = arrow::flight::protocol; + +/// Write Flight message on gRPC stream with zero-copy optimizations. +// Returns Invalid if the payload is ill-formed +// Returns true if the payload was written, false if it was not +// (likely due to disconnect or end-of-stream, e.g. via an +// asynchronous cancellation) +arrow::Result WritePayload( + const FlightPayload& payload, + ::grpc::ClientReaderWriter* writer); +arrow::Result WritePayload( + const FlightPayload& payload, + ::grpc::ClientReaderWriter* writer); +arrow::Result WritePayload( + const FlightPayload& payload, + ::grpc::ServerReaderWriter* writer); +arrow::Result WritePayload(const FlightPayload& payload, + ::grpc::ServerWriter* writer); + +/// Read Flight message from gRPC stream with zero-copy optimizations. +/// True is returned on success, false if stream ended. 
+bool ReadPayload(::grpc::ClientReader* reader, + flight::internal::FlightData* data); +bool ReadPayload(::grpc::ClientReaderWriter* reader, + flight::internal::FlightData* data); +bool ReadPayload(::grpc::ServerReaderWriter* reader, + flight::internal::FlightData* data); +bool ReadPayload(::grpc::ServerReaderWriter* reader, + flight::internal::FlightData* data); +// Overload to make genericity easier in DoPutPayloadWriter +bool ReadPayload(::grpc::ClientReaderWriter* reader, + pb::PutResult* data); + +} // namespace grpc +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.cc b/cpp/src/arrow/flight/transport/grpc/util_internal.cc new file mode 100644 index 00000000000..5268df160e9 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/util_internal.cc @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/transport/grpc/util_internal.h" + +#include +#include +#include +#include +#include + +#ifdef GRPCPP_PP_INCLUDE +#include +#else +#include +#endif + +#include "arrow/flight/types.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { +namespace transport { +namespace grpc { + +const char* kGrpcAuthHeader = "auth-token-bin"; +const char* kGrpcStatusCodeHeader = "x-arrow-status"; +const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin"; +const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin"; +const char* kBinaryErrorDetailsKey = "grpc-status-details-bin"; + +static Status StatusCodeFromString(const ::grpc::string_ref& code_ref, StatusCode* code) { + // Bounce through std::string to get a proper null-terminated C string + const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str()); + switch (code_int) { + case static_cast(StatusCode::OutOfMemory): + case static_cast(StatusCode::KeyError): + case static_cast(StatusCode::TypeError): + case static_cast(StatusCode::Invalid): + case static_cast(StatusCode::IOError): + case static_cast(StatusCode::CapacityError): + case static_cast(StatusCode::IndexError): + case static_cast(StatusCode::UnknownError): + case static_cast(StatusCode::NotImplemented): + case static_cast(StatusCode::SerializationError): + case static_cast(StatusCode::RError): + case static_cast(StatusCode::CodeGenError): + case static_cast(StatusCode::ExpressionValidationError): + case static_cast(StatusCode::ExecutionError): + case static_cast(StatusCode::AlreadyExists): { + *code = static_cast(code_int); + return Status::OK(); + } + default: + // Code is invalid + return Status::UnknownError("Unknown Arrow status code", code_ref); + } +} + +/// Try to extract a status from gRPC trailers. +/// Return Status::OK if found, an error otherwise. 
+static Status FromGrpcContext(const ::grpc::ClientContext& ctx, Status* status, + std::shared_ptr flight_status_detail) { + const std::multimap<::grpc::string_ref, ::grpc::string_ref>& trailers = + ctx.GetServerTrailingMetadata(); + const auto code_val = trailers.find(kGrpcStatusCodeHeader); + if (code_val == trailers.end()) { + return Status::IOError("Status code header not found"); + } + + const ::grpc::string_ref code_ref = code_val->second; + StatusCode code = {}; + RETURN_NOT_OK(StatusCodeFromString(code_ref, &code)); + + const auto message_val = trailers.find(kGrpcStatusMessageHeader); + if (message_val == trailers.end()) { + return Status::IOError("Status message header not found"); + } + + const ::grpc::string_ref message_ref = message_val->second; + std::string message = std::string(message_ref.data(), message_ref.size()); + const auto detail_val = trailers.find(kGrpcStatusDetailHeader); + if (detail_val != trailers.end()) { + const ::grpc::string_ref detail_ref = detail_val->second; + message += ". Detail: "; + message += std::string(detail_ref.data(), detail_ref.size()); + } + const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey); + if (grpc_detail_val != trailers.end()) { + const ::grpc::string_ref detail_ref = grpc_detail_val->second; + std::string bin_detail = std::string(detail_ref.data(), detail_ref.size()); + if (!flight_status_detail) { + flight_status_detail = + std::make_shared(FlightStatusCode::Internal); + } + flight_status_detail->set_extra_info(bin_detail); + } + *status = Status(code, message, flight_status_detail); + return Status::OK(); +} + +/// Convert a gRPC status to an Arrow status, ignoring any +/// implementation-defined headers that encode further detail. 
+static Status FromGrpcCode(const ::grpc::Status& grpc_status) { + switch (grpc_status.error_code()) { + case ::grpc::StatusCode::OK: + return Status::OK(); + case ::grpc::StatusCode::CANCELLED: + return Status::IOError("gRPC cancelled call, with message: ", + grpc_status.error_message()) + .WithDetail(std::make_shared(FlightStatusCode::Cancelled)); + case ::grpc::StatusCode::UNKNOWN: { + std::stringstream ss; + ss << "Flight RPC failed with message: " << grpc_status.error_message(); + return Status::UnknownError(ss.str()).WithDetail( + std::make_shared(FlightStatusCode::Failed)); + } + case ::grpc::StatusCode::INVALID_ARGUMENT: + return Status::Invalid("gRPC returned invalid argument error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::DEADLINE_EXCEEDED: + return Status::IOError("gRPC returned deadline exceeded error, with message: ", + grpc_status.error_message()) + .WithDetail(std::make_shared(FlightStatusCode::TimedOut)); + case ::grpc::StatusCode::NOT_FOUND: + return Status::KeyError("gRPC returned not found error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::ALREADY_EXISTS: + return Status::AlreadyExists("gRPC returned already exists error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::PERMISSION_DENIED: + return Status::IOError("gRPC returned permission denied error, with message: ", + grpc_status.error_message()) + .WithDetail( + std::make_shared(FlightStatusCode::Unauthorized)); + case ::grpc::StatusCode::RESOURCE_EXHAUSTED: + return Status::Invalid("gRPC returned resource exhausted error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::FAILED_PRECONDITION: + return Status::Invalid("gRPC returned precondition failed error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::ABORTED: + return Status::IOError("gRPC returned aborted error, with message: ", + grpc_status.error_message()) + 
.WithDetail(std::make_shared(FlightStatusCode::Internal)); + case ::grpc::StatusCode::OUT_OF_RANGE: + return Status::Invalid("gRPC returned out-of-range error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::UNIMPLEMENTED: + return Status::NotImplemented("gRPC returned unimplemented error, with message: ", + grpc_status.error_message()); + case ::grpc::StatusCode::INTERNAL: + return Status::IOError("gRPC returned internal error, with message: ", + grpc_status.error_message()) + .WithDetail(std::make_shared(FlightStatusCode::Internal)); + case ::grpc::StatusCode::UNAVAILABLE: + return Status::IOError("gRPC returned unavailable error, with message: ", + grpc_status.error_message()) + .WithDetail( + std::make_shared(FlightStatusCode::Unavailable)); + case ::grpc::StatusCode::DATA_LOSS: + return Status::IOError("gRPC returned data loss error, with message: ", + grpc_status.error_message()) + .WithDetail(std::make_shared(FlightStatusCode::Internal)); + case ::grpc::StatusCode::UNAUTHENTICATED: + return Status::IOError("gRPC returned unauthenticated error, with message: ", + grpc_status.error_message()) + .WithDetail( + std::make_shared(FlightStatusCode::Unauthenticated)); + default: + return Status::UnknownError("gRPC failed with error code ", + grpc_status.error_code(), + " and message: ", grpc_status.error_message()); + } +} + +Status FromGrpcStatus(const ::grpc::Status& grpc_status, ::grpc::ClientContext* ctx) { + const Status status = FromGrpcCode(grpc_status); + + if (!status.ok() && ctx) { + Status arrow_status; + + if (!FromGrpcContext(*ctx, &arrow_status, FlightStatusDetail::UnwrapStatus(status)) + .ok()) { + // If we fail to decode a more detailed status from the headers, + // proceed normally + return status; + } + + return arrow_status; + } + return status; +} + +/// Convert an Arrow status to a gRPC status. 
+static ::grpc::Status ToRawGrpcStatus(const Status& arrow_status) { + if (arrow_status.ok()) { + return ::grpc::Status::OK; + } + + ::grpc::StatusCode grpc_code = ::grpc::StatusCode::UNKNOWN; + std::string message = arrow_status.message(); + if (arrow_status.detail()) { + message += ". Detail: "; + message += arrow_status.detail()->ToString(); + } + + std::shared_ptr flight_status = + FlightStatusDetail::UnwrapStatus(arrow_status); + if (flight_status) { + switch (flight_status->code()) { + case FlightStatusCode::Internal: + grpc_code = ::grpc::StatusCode::INTERNAL; + break; + case FlightStatusCode::TimedOut: + grpc_code = ::grpc::StatusCode::DEADLINE_EXCEEDED; + break; + case FlightStatusCode::Cancelled: + grpc_code = ::grpc::StatusCode::CANCELLED; + break; + case FlightStatusCode::Unauthenticated: + grpc_code = ::grpc::StatusCode::UNAUTHENTICATED; + break; + case FlightStatusCode::Unauthorized: + grpc_code = ::grpc::StatusCode::PERMISSION_DENIED; + break; + case FlightStatusCode::Unavailable: + grpc_code = ::grpc::StatusCode::UNAVAILABLE; + break; + default: + break; + } + } else if (arrow_status.IsNotImplemented()) { + grpc_code = ::grpc::StatusCode::UNIMPLEMENTED; + } else if (arrow_status.IsInvalid()) { + grpc_code = ::grpc::StatusCode::INVALID_ARGUMENT; + } else if (arrow_status.IsKeyError()) { + grpc_code = ::grpc::StatusCode::NOT_FOUND; + } else if (arrow_status.IsAlreadyExists()) { + grpc_code = ::grpc::StatusCode::ALREADY_EXISTS; + } + return ::grpc::Status(grpc_code, message); +} + +/// Convert an Arrow status to a gRPC status, and add extra headers to +/// the response to encode the original Arrow status. 
+::grpc::Status ToGrpcStatus(const Status& arrow_status, ::grpc::ServerContext* ctx) { + ::grpc::Status status = ToRawGrpcStatus(arrow_status); + if (!status.ok() && ctx) { + const std::string code = std::to_string(static_cast(arrow_status.code())); + ctx->AddTrailingMetadata(kGrpcStatusCodeHeader, code); + ctx->AddTrailingMetadata(kGrpcStatusMessageHeader, arrow_status.message()); + if (arrow_status.detail()) { + const std::string detail_string = arrow_status.detail()->ToString(); + ctx->AddTrailingMetadata(kGrpcStatusDetailHeader, detail_string); + } + auto fsd = FlightStatusDetail::UnwrapStatus(arrow_status); + if (fsd && !fsd->extra_info().empty()) { + ctx->AddTrailingMetadata(kBinaryErrorDetailsKey, fsd->extra_info()); + } + } + + return status; +} + +} // namespace grpc +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.h b/cpp/src/arrow/flight/transport/grpc/util_internal.h new file mode 100644 index 00000000000..a267e556544 --- /dev/null +++ b/cpp/src/arrow/flight/transport/grpc/util_internal.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "arrow/flight/transport/grpc/protocol_grpc_internal.h" +#include "arrow/flight/visibility.h" +#include "arrow/util/macros.h" + +namespace grpc { + +class Status; + +} // namespace grpc + +namespace arrow { + +class Status; + +namespace flight { + +#define GRPC_RETURN_NOT_OK(expr) \ + do { \ + ::arrow::Status _s = (expr); \ + if (ARROW_PREDICT_FALSE(!_s.ok())) { \ + return ::arrow::flight::transport::grpc::ToGrpcStatus(_s); \ + } \ + } while (0) + +#define GRPC_RETURN_NOT_GRPC_OK(expr) \ + do { \ + ::grpc::Status _s = (expr); \ + if (ARROW_PREDICT_FALSE(!_s.ok())) { \ + return _s; \ + } \ + } while (0) + +namespace transport { +namespace grpc { + +/// The name of the header used to pass authentication tokens. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcAuthHeader; + +/// The name of the header used to pass the exact Arrow status code. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcStatusCodeHeader; + +/// The name of the header used to pass the exact Arrow status message. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcStatusMessageHeader; + +/// The name of the header used to pass the exact Arrow status detail. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcStatusDetailHeader; + +ARROW_FLIGHT_EXPORT +extern const char* kBinaryErrorDetailsKey; + +/// Convert a gRPC status to an Arrow status. Optionally, provide a +/// ClientContext to recover the exact Arrow status if it was passed +/// over the wire. 
+ARROW_FLIGHT_EXPORT +Status FromGrpcStatus(const ::grpc::Status& grpc_status, + ::grpc::ClientContext* ctx = nullptr); + +ARROW_FLIGHT_EXPORT +::grpc::Status ToGrpcStatus(const Status& arrow_status, + ::grpc::ServerContext* ctx = nullptr); + +} // namespace grpc +} // namespace transport +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport_server.cc b/cpp/src/arrow/flight/transport_server.cc new file mode 100644 index 00000000000..daeebcc2405 --- /dev/null +++ b/cpp/src/arrow/flight/transport_server.cc @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/flight/transport_server.h" + +#include + +#include "arrow/buffer.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/server.h" +#include "arrow/flight/types.h" +#include "arrow/ipc/reader.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace flight { +namespace internal { + +Status ServerDataStream::WritePutMetadata(const Buffer&) { + return Status::NotImplemented("Writing put metadata for this stream"); +} + +namespace { +class TransportIpcMessageReader : public ipc::MessageReader { + public: + TransportIpcMessageReader( + std::shared_ptr peekable_reader, + std::shared_ptr memory_manager, + std::shared_ptr* app_metadata) + : peekable_reader_(peekable_reader), + memory_manager_(std::move(memory_manager)), + app_metadata_(app_metadata) {} + + ::arrow::Result> ReadNextMessage() override { + if (stream_finished_) return nullptr; + internal::FlightData* data; + peekable_reader_->Next(&data); + if (!data) { + stream_finished_ = true; + return nullptr; + } + if (data->body && + ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) { + ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_)); + } + *app_metadata_ = std::move(data->app_metadata); + return data->OpenMessage(); + } + + protected: + std::shared_ptr peekable_reader_; + std::shared_ptr memory_manager_; + // A reference to TransportDataStream.app_metadata_. That class + // can't access the app metadata because when it Peek()s the stream, + // it may be looking at a dictionary batch, not the record + // batch. Updating it here ensures the reader is always updated with + // the last metadata message read. + std::shared_ptr* app_metadata_; + bool stream_finished_ = false; +}; + +/// \brief Adapt TransportDataStream to the FlightMessageReader +/// interface for DoPut. 
+class TransportMessageReader final : public FlightMessageReader { + public: + TransportMessageReader(ServerDataStream* stream, + std::shared_ptr memory_manager) + : peekable_reader_(new internal::PeekableFlightDataReader(stream)), + memory_manager_(std::move(memory_manager)) {} + + Status Init() { + // Peek the first message to get the descriptor. + internal::FlightData* data; + peekable_reader_->Peek(&data); + if (!data) { + return Status::IOError("Stream finished before first message sent"); + } + if (!data->descriptor) { + return Status::IOError("Descriptor missing on first message"); + } + descriptor_ = *data->descriptor; + // If there's a schema (=DoPut), also Open(). + if (data->metadata) { + return EnsureDataStarted(); + } + peekable_reader_->Next(&data); + return Status::OK(); + } + + const FlightDescriptor& descriptor() const override { return descriptor_; } + + arrow::Result> GetSchema() override { + RETURN_NOT_OK(EnsureDataStarted()); + return batch_reader_->schema(); + } + + Status Next(FlightStreamChunk* out) override { + internal::FlightData* data; + peekable_reader_->Peek(&data); + if (!data) { + out->app_metadata = nullptr; + out->data = nullptr; + return Status::OK(); + } + + if (!data->metadata) { + // Metadata-only (data->metadata is the IPC header) + out->app_metadata = data->app_metadata; + out->data = nullptr; + peekable_reader_->Next(&data); + return Status::OK(); + } + + if (!batch_reader_) { + RETURN_NOT_OK(EnsureDataStarted()); + // re-peek here since EnsureDataStarted() advances the stream + return Next(out); + } + RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); + out->app_metadata = std::move(app_metadata_); + return Status::OK(); + } + + private: + /// Ensure we are set up to read data. 
+ Status EnsureDataStarted() { + if (!batch_reader_) { + // peek() until we find the first data message; discard metadata + if (!peekable_reader_->SkipToData()) { + return Status::IOError("Client never sent a data message"); + } + auto message_reader = + std::unique_ptr(new TransportIpcMessageReader( + peekable_reader_, memory_manager_, &app_metadata_)); + ARROW_ASSIGN_OR_RAISE( + batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader))); + } + return Status::OK(); + } + + FlightDescriptor descriptor_; + std::shared_ptr peekable_reader_; + std::shared_ptr memory_manager_; + std::shared_ptr batch_reader_; + std::shared_ptr app_metadata_; +}; + +// TODO(ARROW-10787): this should use the same writer/ipc trick as client +class TransportMessageWriter final : public FlightMessageWriter { + public: + explicit TransportMessageWriter(ServerDataStream* stream) + : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {} + + Status Begin(const std::shared_ptr& schema, + const ipc::IpcWriteOptions& options) override { + if (started_) { + return Status::Invalid("This writer has already been started."); + } + started_ = true; + ipc_options_ = options; + + RETURN_NOT_OK(mapper_.AddSchemaFields(*schema)); + FlightPayload schema_payload; + RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_, + &schema_payload.ipc_message)); + return WritePayload(schema_payload); + } + + Status WriteRecordBatch(const RecordBatch& batch) override { + return WriteWithMetadata(batch, nullptr); + } + + Status WriteMetadata(std::shared_ptr app_metadata) override { + FlightPayload payload{}; + payload.app_metadata = app_metadata; + return WritePayload(payload); + } + + Status WriteWithMetadata(const RecordBatch& batch, + std::shared_ptr app_metadata) override { + RETURN_NOT_OK(CheckStarted()); + RETURN_NOT_OK(EnsureDictionariesWritten(batch)); + FlightPayload payload{}; + if (app_metadata) { + payload.app_metadata = app_metadata; + } + 
RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message)); + RETURN_NOT_OK(WritePayload(payload)); + ++stats_.num_record_batches; + return Status::OK(); + } + + Status Close() override { + // It's fine to Close() without writing data + return Status::OK(); + } + + ipc::WriteStats stats() const override { return stats_; } + + private: + Status WritePayload(const FlightPayload& payload) { + ARROW_ASSIGN_OR_RAISE(auto success, stream_->WriteData(payload)); + if (!success) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not write metadata to stream (client disconnect?)"); + } + ++stats_.num_messages; + return Status::OK(); + } + + Status CheckStarted() { + if (!started_) { + return Status::Invalid("This writer is not started. Call Begin() with a schema"); + } + return Status::OK(); + } + + Status EnsureDictionariesWritten(const RecordBatch& batch) { + if (dictionaries_written_) { + return Status::OK(); + } + dictionaries_written_ = true; + ARROW_ASSIGN_OR_RAISE(const auto dictionaries, + ipc::CollectDictionaries(batch, mapper_)); + for (const auto& pair : dictionaries) { + FlightPayload payload{}; + RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_, + &payload.ipc_message)); + RETURN_NOT_OK(WritePayload(payload)); + ++stats_.num_dictionary_batches; + } + return Status::OK(); + } + + ServerDataStream* stream_; + ::arrow::ipc::IpcWriteOptions ipc_options_; + ipc::DictionaryFieldMapper mapper_; + ipc::WriteStats stats_; + bool started_ = false; + bool dictionaries_written_ = false; +}; + +/// \brief Adapt TransportDataStream to the FlightMetadataWriter +/// interface for DoPut. 
+class TransportMetadataWriter final : public FlightMetadataWriter { + public: + explicit TransportMetadataWriter(ServerDataStream* stream) : stream_(stream) {} + + Status WriteMetadata(const Buffer& buffer) override { + return stream_->WritePutMetadata(buffer); + } + + private: + ServerDataStream* stream_; +}; +} // namespace + +Status ServerTransport::DoGet(const ServerCallContext& context, const Ticket& ticket, + ServerDataStream* stream) { + std::unique_ptr data_stream; + RETURN_NOT_OK(base_->DoGet(context, ticket, &data_stream)); + + if (!data_stream) return Status::KeyError("No data in this flight"); + + // Write the schema as the first message in the stream + FlightPayload schema_payload; + RETURN_NOT_OK(data_stream->GetSchemaPayload(&schema_payload)); + ARROW_ASSIGN_OR_RAISE(auto success, stream->WriteData(schema_payload)); + // Connection terminated + if (!success) return Status::OK(); + + // Consume data stream and write out payloads + while (true) { + FlightPayload payload; + RETURN_NOT_OK(data_stream->Next(&payload)); + // End of stream + if (payload.ipc_message.metadata == nullptr) break; + ARROW_ASSIGN_OR_RAISE(auto success, stream->WriteData(payload)); + // Connection terminated + if (!success) return Status::OK(); + } + RETURN_NOT_OK(stream->WritesDone()); + return Status::OK(); +} + +Status ServerTransport::DoPut(const ServerCallContext& context, + ServerDataStream* stream) { + std::unique_ptr reader( + new TransportMessageReader(stream, memory_manager_)); + std::unique_ptr writer(new TransportMetadataWriter(stream)); + RETURN_NOT_OK(reader->Init()); + RETURN_NOT_OK(base_->DoPut(context, std::move(reader), std::move(writer))); + RETURN_NOT_OK(stream->WritesDone()); + return Status::OK(); +} + +Status ServerTransport::DoExchange(const ServerCallContext& context, + ServerDataStream* stream) { + std::unique_ptr reader( + new TransportMessageReader(stream, memory_manager_)); + std::unique_ptr writer(new TransportMessageWriter(stream)); + 
RETURN_NOT_OK(reader->Init()); + RETURN_NOT_OK(base_->DoExchange(context, std::move(reader), std::move(writer))); + RETURN_NOT_OK(stream->WritesDone()); + return Status::OK(); +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/transport_server.h b/cpp/src/arrow/flight/transport_server.h new file mode 100644 index 00000000000..51105a89304 --- /dev/null +++ b/cpp/src/arrow/flight/transport_server.h @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/flight/transport.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/flight/visibility.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace ipc { +class Message; +} +namespace flight { +namespace internal { + +/// \brief A transport-specific interface for reading/writing Arrow +/// data for a server. +class ARROW_FLIGHT_EXPORT ServerDataStream : public TransportDataStream { + public: + /// \brief Attempt to write a non-data message. + /// + /// Only implemented for DoPut; mutually exclusive with + /// WriteData(const FlightPayload&). 
+ virtual Status WritePutMetadata(const Buffer& payload); +}; + +/// \brief An implementation of a Flight server for a particular +/// transport. +/// +/// This class (the transport implementation) implements the underlying +/// server and handles connections/incoming RPC calls. It should forward RPC +/// calls to the RPC handlers defined on this class, which work in terms of +/// the generic interfaces above. The RPC handlers here then forward calls +/// to the underlying FlightServerBase instance that contains the actual +/// application RPC method handlers. +/// +/// Used by FlightServerBase to manage the server lifecycle. +class ARROW_FLIGHT_EXPORT ServerTransport { + public: + ServerTransport(FlightServerBase* base, std::shared_ptr memory_manager) + : base_(base), memory_manager_(std::move(memory_manager)) {} + virtual ~ServerTransport() = default; + + /// \name Server Lifecycle Methods + /// Transports implement these methods to start/shutdown the underlying + /// server. + /// @{ + /// \brief Initialize the server. + /// + /// This method should launch the server in a background thread, i.e. it + /// should not block. Once this returns, the server should be active. + virtual Status Init(const FlightServerOptions& options, + const arrow::internal::Uri& uri) = 0; + /// \brief Shutdown the server. + /// + /// This should wait for active RPCs to finish. Once this returns, the + /// server is no longer listening. + virtual Status Shutdown() = 0; + /// \brief Shutdown the server with a deadline. + /// + /// This should wait for active RPCs to finish, or for the deadline to + /// expire. Once this returns, the server is no longer listening. + virtual Status Shutdown(const std::chrono::system_clock::time_point& deadline) = 0; + /// \brief Wait for the server to shutdown (but do not shut down the server). + /// + /// Once this returns, the server is no longer listening. 
+ virtual Status Wait() = 0; + /// \brief Get the address the server is listening on, else an empty Location. + virtual Location location() const = 0; + ///@} + + /// \name RPC Handlers + /// Implementations of RPC handlers for Flight methods using the common + /// interfaces here. Transports should call these methods from their + /// server implementation to handle the actual RPC calls. + ///@{ + /// \brief Get the FlightServerBase. + /// + /// Intended as an escape hatch for now since not all methods have been + /// factored into a transport-agnostic interface. + FlightServerBase* base() const { return base_; } + /// \brief Implement DoGet in terms of a transport-level stream. + /// + /// \param[in] context The server context. + /// \param[in] request The request payload. + /// \param[in] stream The transport-specific data stream + /// implementation. Must implement WriteData(const + /// FlightPayload&). + Status DoGet(const ServerCallContext& context, const Ticket& request, + ServerDataStream* stream); + /// \brief Implement DoPut in terms of a transport-level stream. + /// + /// \param[in] context The server context. + /// \param[in] stream The transport-specific data stream + /// implementation. Must implement ReadData(FlightData*) + /// and WritePutMetadata(const Buffer&). + Status DoPut(const ServerCallContext& context, ServerDataStream* stream); + /// \brief Implement DoExchange in terms of a transport-level stream. + /// + /// \param[in] context The server context. + /// \param[in] stream The transport-specific data stream + /// implementation. Must implement ReadData(FlightData*) + /// and WriteData(const FlightPayload&). 
+ Status DoExchange(const ServerCallContext& context, ServerDataStream* stream); + ///@} + + protected: + FlightServerBase* base_; + std::shared_ptr memory_manager_; +}; + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/type_fwd.h b/cpp/src/arrow/flight/type_fwd.h new file mode 100644 index 00000000000..c82c4e6d8f5 --- /dev/null +++ b/cpp/src/arrow/flight/type_fwd.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +namespace arrow { +namespace internal { +class Uri; +} +namespace flight { +struct Action; +struct ActionType; +struct BasicAuth; +class ClientAuthHandler; +class ClientMiddleware; +class ClientMiddlewareFactory; +struct Criteria; +class FlightCallOptions; +struct FlightClientOptions; +struct FlightDescriptor; +struct FlightEndpoint; +class FlightInfo; +class FlightListing; +class FlightMetadataReader; +class FlightMetadataWriter; +struct FlightPayload; +class FlightServerBase; +class FlightServerOptions; +class FlightStreamReader; +class FlightStreamWriter; +struct Location; +struct Result; +class ResultStream; +struct SchemaResult; +class ServerCallContext; +class ServerMiddleware; +class ServerMiddlewareFactory; +struct Ticket; +namespace internal { +class ClientTransport; +struct FlightData; +class ServerTransport; +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 6338c36567b..3dc3c1645ef 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -188,7 +188,7 @@ bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; arrow::Result Ticket::SerializeToString() const { pb::Ticket pb_ticket; - internal::ToProto(*this, &pb_ticket); + RETURN_NOT_OK(internal::ToProto(*this, &pb_ticket)); std::string out; if (!pb_ticket.SerializeToString(&out)) { diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index b33e8946bf3..cf5a08bf3bf 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -1371,9 +1371,10 @@ namespace internal { Result> OpenRecordBatchWriter( std::unique_ptr sink, const std::shared_ptr& schema, const IpcWriteOptions& options) { - // XXX should we call Start()? 
- return ::arrow::internal::make_unique( + auto writer = ::arrow::internal::make_unique( std::move(sink), schema, options, /*is_file_format=*/false); + RETURN_NOT_OK(writer->Start()); + return std::move(writer); } Result> MakePayloadStreamWriter( diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index 1204c64f920..6bc5659c74d 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -18,7 +18,7 @@ #include #include -#include "arrow/flight/internal.h" +#include "arrow/flight/serialization_internal.h" #include "arrow/python/flight.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" diff --git a/docs/source/cpp/api/flight.rst b/docs/source/cpp/api/flight.rst index 7cefd66ef84..40bdfe008d8 100644 --- a/docs/source/cpp/api/flight.rst +++ b/docs/source/cpp/api/flight.rst @@ -200,3 +200,9 @@ error codes. .. doxygenfunction:: arrow::flight::MakeFlightError :project: arrow_cpp + +Implementing Custom Transports +============================== + +.. doxygenfile:: arrow/flight/transport_impl.h + :sections: briefdescription detaileddescription