From 03320f31028135b3365283e465803d25c3a2e50a Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 18 Feb 2022 14:20:45 -0500
Subject: [PATCH 01/30] ARROW-15282: [C++][FlightRPC] Split data methods from
the underlying transport
---
cpp/src/arrow/flight/CMakeLists.txt | 2 +
cpp/src/arrow/flight/client.cc | 867 ++++++++----------
cpp/src/arrow/flight/client.h | 9 +-
cpp/src/arrow/flight/serialization_internal.h | 27 +-
cpp/src/arrow/flight/server.cc | 556 ++++-------
cpp/src/arrow/flight/server.h | 4 +
cpp/src/arrow/flight/sql/client.cc | 2 +-
cpp/src/arrow/flight/test_util.h | 29 +-
cpp/src/arrow/flight/transport_impl.cc | 163 ++++
cpp/src/arrow/flight/transport_impl.h | 221 +++++
cpp/src/arrow/flight/transport_server_impl.cc | 327 +++++++
cpp/src/arrow/flight/type_fwd.h | 47 +
cpp/src/arrow/ipc/writer.cc | 5 +-
13 files changed, 1389 insertions(+), 870 deletions(-)
create mode 100644 cpp/src/arrow/flight/transport_impl.cc
create mode 100644 cpp/src/arrow/flight/transport_impl.h
create mode 100644 cpp/src/arrow/flight/transport_server_impl.cc
create mode 100644 cpp/src/arrow/flight/type_fwd.h
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index cae36562be3..14b3c4efe0b 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -175,6 +175,8 @@ set(ARROW_FLIGHT_SRCS
serialization_internal.cc
server.cc
server_auth.cc
+ transport_impl.cc
+ transport_server_impl.cc
types.cc)
add_arrow_lib(arrow_flight
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 39d6cebf41f..e18f5f37f6f 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -57,6 +57,7 @@
#include "arrow/flight/middleware.h"
#include "arrow/flight/middleware_internal.h"
#include "arrow/flight/serialization_internal.h"
+#include "arrow/flight/transport_impl.h"
#include "arrow/flight/types.h"
namespace arrow {
@@ -127,137 +128,6 @@ struct ClientRpc {
}
};
-/// Helper that manages Finish() of a gRPC stream.
-///
-/// When we encounter an error (e.g. could not decode an IPC message),
-/// we want to provide both the client-side error context and any
-/// available server-side context. This helper helps wrap up that
-/// logic.
-///
-/// This class protects the stream with a flag (so that Finish is
-/// idempotent), and drains the read side (so that Finish won't hang).
-///
-/// The template lets us abstract between DoGet/DoExchange and DoPut,
-/// which respectively read internal::FlightData and pb::PutResult.
-template
-class FinishableStream {
- public:
- FinishableStream(std::shared_ptr rpc, std::shared_ptr stream)
- : rpc_(rpc), stream_(stream), finished_(false), server_status_() {}
- virtual ~FinishableStream() = default;
-
- /// \brief Get the underlying stream.
- std::shared_ptr stream() const { return stream_; }
-
- /// \brief Finish the call, adding server context to the given status.
- virtual Status Finish(Status st) {
- if (finished_) {
- return MergeStatus(std::move(st));
- }
-
- // Drain the read side, as otherwise gRPC Finish() will hang. We
- // only call Finish() when the client closes the writer or the
- // reader finishes, so it's OK to assume the client no longer
- // wants to read and drain the read side. (If the client wants to
- // indicate that it is done writing, but not done reading, it
- // should use DoneWriting.
- ReadT message;
- while (internal::ReadPayload(stream_.get(), &message)) {
- // Drain the read side to avoid gRPC hanging in Finish()
- }
-
- server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context);
- finished_ = true;
-
- return MergeStatus(std::move(st));
- }
-
- private:
- Status MergeStatus(Status&& st) {
- if (server_status_.ok()) {
- return std::move(st);
- }
- return Status::FromDetailAndArgs(
- server_status_.code(), server_status_.detail(), server_status_.message(),
- ". Client context: ", st.ToString(),
- ". gRPC client debug context: ", rpc_->context.debug_error_string());
- }
-
- std::shared_ptr rpc_;
- std::shared_ptr stream_;
- bool finished_;
- Status server_status_;
-};
-
-/// Helper that manages \a Finish() of a read-write gRPC stream.
-///
-/// This also calls \a WritesDone() and protects itself with a mutex
-/// to enable sharing between the reader and writer.
-template
-class FinishableWritableStream : public FinishableStream {
- public:
- FinishableWritableStream(std::shared_ptr rpc,
- std::shared_ptr read_mutex,
- std::shared_ptr stream)
- : FinishableStream(rpc, stream),
- finish_mutex_(),
- read_mutex_(read_mutex),
- done_writing_(false) {}
- virtual ~FinishableWritableStream() = default;
-
- /// \brief Indicate to gRPC that the write half of the stream is done.
- Status DoneWriting() {
- // This is only used by the writer side of a stream, so it need
- // not be protected with a lock.
- if (done_writing_) {
- return Status::OK();
- }
- done_writing_ = true;
- if (!this->stream()->WritesDone()) {
- // Error happened, try to close the stream to get more detailed info
- return Finish(MakeFlightError(FlightStatusCode::Internal,
- "Could not flush pending record batches"));
- }
- return Status::OK();
- }
-
- Status Finish(Status st) override {
- // This may be used concurrently by reader/writer side of a
- // stream, so it needs to be protected.
- std::lock_guard guard(finish_mutex_);
-
- // Now that we're shared between a reader and writer, we need to
- // protect ourselves from being called while there's an
- // outstanding read.
- std::unique_lock read_guard(*read_mutex_, std::try_to_lock);
- if (!read_guard.owns_lock()) {
- return MakeFlightError(
- FlightStatusCode::Internal,
- "Cannot close stream with pending read operation. Client context: " +
- st.ToString());
- }
-
- // Try to flush pending writes. Don't use our WritesDone() to
- // avoid recursion.
- bool finished_writes = done_writing_ || this->stream()->WritesDone();
- done_writing_ = true;
-
- st = FinishableStream::Finish(std::move(st));
-
- if (!finished_writes) {
- return Status::FromDetailAndArgs(
- st.code(), st.detail(), st.message(),
- ". Additionally, could not finish writing record batches before closing");
- }
- return st;
- }
-
- private:
- std::mutex finish_mutex_;
- std::shared_ptr read_mutex_;
- bool done_writing_;
-};
-
class GrpcAddCallHeaders : public AddCallHeaders {
public:
explicit GrpcAddCallHeaders(std::multimap* metadata)
@@ -428,22 +298,18 @@ class GrpcClientAuthReader : public ClientAuthReader {
stream_;
};
-// An ipc::MessageReader that adapts any readable gRPC stream
-// returning FlightData.
-template
-class GrpcIpcMessageReader : public ipc::MessageReader {
+/// \brief An ipc::MessageReader adapting the Flight ClientDataStream interface.
+///
+/// In order to support app_metadata and reuse the existing IPC
+/// infrastructure, this takes a pointer to a buffer (provided by the
+/// FlightStreamReader implementation) and upon reading a message,
+/// updates that buffer with the one read from the server.
+class IpcMessageReader : public ipc::MessageReader {
public:
- GrpcIpcMessageReader(
- std::shared_ptr rpc, std::shared_ptr memory_manager,
- std::shared_ptr read_mutex,
- std::shared_ptr> stream,
- std::shared_ptr>>
- peekable_reader,
- std::shared_ptr* app_metadata)
- : rpc_(rpc),
- memory_manager_(std::move(memory_manager)),
- read_mutex_(read_mutex),
- stream_(std::move(stream)),
+ IpcMessageReader(std::shared_ptr stream,
+ std::shared_ptr peekable_reader,
+ std::shared_ptr* app_metadata)
+ : stream_(std::move(stream)),
peekable_reader_(peekable_reader),
app_metadata_(app_metadata),
stream_finished_(false) {}
@@ -453,20 +319,11 @@ class GrpcIpcMessageReader : public ipc::MessageReader {
return nullptr;
}
internal::FlightData* data;
- {
- auto guard = read_mutex_ ? std::unique_lock(*read_mutex_)
- : std::unique_lock();
- peekable_reader_->Next(&data);
- }
+ peekable_reader_->Next(&data);
if (!data) {
stream_finished_ = true;
return stream_->Finish(Status::OK());
}
-
- if (ARROW_PREDICT_FALSE(!memory_manager_->is_cpu() && data->body)) {
- ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_));
- }
-
// Validate IPC message
auto result = data->OpenMessage();
if (!result.ok()) {
@@ -477,15 +334,9 @@ class GrpcIpcMessageReader : public ipc::MessageReader {
}
private:
- // The RPC context lifetime must be coupled to the ClientReader
- std::shared_ptr rpc_;
+ std::shared_ptr stream_;
+ std::shared_ptr peekable_reader_;
std::shared_ptr memory_manager_;
- // Guard reads with a mutex to prevent concurrent reads if the write
- // side calls Finish(). Nullable as DoGet doesn't need this.
- std::shared_ptr read_mutex_;
- std::shared_ptr> stream_;
- std::shared_ptr>>
- peekable_reader_;
// A reference to GrpcStreamReader.app_metadata_. That class
// can't access the app metadata because when it Peek()s the stream,
// it may be looking at a dictionary batch, not the record
@@ -495,34 +346,21 @@ class GrpcIpcMessageReader : public ipc::MessageReader {
bool stream_finished_;
};
-/// The implementation of the public-facing API for reading from a
-/// FlightData stream
-template
-class GrpcStreamReader : public FlightStreamReader {
+/// \brief A reader for any ClientDataStream.
+class ClientStreamReader : public FlightStreamReader {
public:
- GrpcStreamReader(std::shared_ptr rpc,
- std::shared_ptr memory_manager,
- std::shared_ptr read_mutex,
- const ipc::IpcReadOptions& options, StopToken stop_token,
- std::shared_ptr> stream)
- : rpc_(rpc),
- memory_manager_(memory_manager ? std::move(memory_manager)
- : CPUDevice::Instance()->default_memory_manager()),
- read_mutex_(read_mutex),
+ ClientStreamReader(std::shared_ptr stream,
+ const ipc::IpcReadOptions& options, StopToken stop_token)
+ : stream_(std::move(stream)),
options_(options),
stop_token_(std::move(stop_token)),
- stream_(stream),
- peekable_reader_(new internal::PeekableFlightDataReader>(
- stream->stream())),
+ peekable_reader_(new internal::PeekableFlightDataReader(stream_.get())),
app_metadata_(nullptr) {}
Status EnsureDataStarted() {
if (!batch_reader_) {
bool skipped_to_data = false;
- {
- auto guard = TakeGuard();
- skipped_to_data = peekable_reader_->SkipToData();
- }
+ skipped_to_data = peekable_reader_->SkipToData();
// peek() until we find the first data message; discard metadata
if (!skipped_to_data) {
return OverrideWithServerError(MakeFlightError(
@@ -530,8 +368,7 @@ class GrpcStreamReader : public FlightStreamReader {
}
auto message_reader = std::unique_ptr(
- new GrpcIpcMessageReader(rpc_, memory_manager_, read_mutex_, stream_,
- peekable_reader_, &app_metadata_));
+ new IpcMessageReader(stream_, peekable_reader_, &app_metadata_));
auto result =
ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_);
RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_)));
@@ -544,10 +381,7 @@ class GrpcStreamReader : public FlightStreamReader {
}
Status Next(FlightStreamChunk* out) override {
internal::FlightData* data;
- {
- auto guard = TakeGuard();
- peekable_reader_->Peek(&data);
- }
+ peekable_reader_->Peek(&data);
if (!data) {
out->app_metadata = nullptr;
out->data = nullptr;
@@ -558,10 +392,7 @@ class GrpcStreamReader : public FlightStreamReader {
// Metadata-only (data->metadata is the IPC header)
out->app_metadata = data->app_metadata;
out->data = nullptr;
- {
- auto guard = TakeGuard();
- peekable_reader_->Next(&data);
- }
+ peekable_reader_->Next(&data);
return Status::OK();
}
@@ -596,14 +427,9 @@ class GrpcStreamReader : public FlightStreamReader {
return ReadAll(table, stop_token_);
}
using FlightStreamReader::ReadAll;
- void Cancel() override { rpc_->context.TryCancel(); }
+ void Cancel() override { stream_->TryCancel(); }
private:
- std::unique_lock TakeGuard() {
- return read_mutex_ ? std::unique_lock(*read_mutex_)
- : std::unique_lock();
- }
-
Status OverrideWithServerError(Status&& st) {
if (st.ok()) {
return std::move(st);
@@ -611,183 +437,247 @@ class GrpcStreamReader : public FlightStreamReader {
return stream_->Finish(std::move(st));
}
- friend class GrpcIpcMessageReader;
- std::shared_ptr rpc_;
- std::shared_ptr memory_manager_;
- // Guard reads with a lock to prevent Finish()/Close() from being
- // called on the writer while the reader has a pending
- // read. Nullable, as DoGet() doesn't need this.
- std::shared_ptr read_mutex_;
+ std::shared_ptr stream_;
ipc::IpcReadOptions options_;
StopToken stop_token_;
- std::shared_ptr> stream_;
- std::shared_ptr>>
- peekable_reader_;
+ std::shared_ptr peekable_reader_;
std::shared_ptr batch_reader_;
std::shared_ptr app_metadata_;
};
-// The next two classes implement writing to a FlightData stream.
-// Similarly to the read side, we want to reuse the implementation of
-// RecordBatchWriter. As a result, these two classes are intertwined
-// in order to pass application metadata "through" RecordBatchWriter.
-// In order to get application-specific metadata to the
-// IpcPayloadWriter, DoPutPayloadWriter takes a pointer to
-// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on
-// write; DoPutPayloadWriter reads that metadata field to determine
-// what to write.
-
-template
-class DoPutPayloadWriter;
-
-template
-class GrpcStreamWriter : public FlightStreamWriter {
+FlightMetadataReader::~FlightMetadataReader() = default;
+
+/// \brief The base of the ClientDataStream implementation for gRPC.
+template
+class FinishableDataStream : public internal::ClientDataStream {
public:
- ~GrpcStreamWriter() override = default;
+ FinishableDataStream(std::shared_ptr rpc, std::shared_ptr stream,
+ std::shared_ptr memory_manager)
+ : rpc_(std::move(rpc)),
+ stream_(std::move(stream)),
+ memory_manager_(memory_manager ? std::move(memory_manager)
+ : CPUDevice::Instance()->default_memory_manager()),
+ finished_(false) {}
- using GrpcStream = grpc::ClientReaderWriter;
+ Status Finish() override {
+ if (finished_) {
+ return server_status_;
+ }
- explicit GrpcStreamWriter(
- const FlightDescriptor& descriptor, std::shared_ptr rpc,
- int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options,
- std::shared_ptr> writer)
- : app_metadata_(nullptr),
- batch_writer_(nullptr),
- writer_(std::move(writer)),
- rpc_(std::move(rpc)),
- write_size_limit_bytes_(write_size_limit_bytes),
- options_(options),
- descriptor_(descriptor),
- writer_closed_(false) {}
+ // Drain the read side, as otherwise gRPC Finish() will hang. We
+ // only call Finish() when the client closes the writer or the
+ // reader finishes, so it's OK to assume the client no longer
+ // wants to read and drain the read side. (If the client wants to
+ // indicate that it is done writing, but not done reading, it
+ // should use DoneWriting.
+ ReadPayload message;
+ while (internal::ReadPayload(stream_.get(), &message)) {
+ // Drain the read side to avoid gRPC hanging in Finish()
+ }
- static Status Open(
- const FlightDescriptor& descriptor, std::shared_ptr schema,
- const ipc::IpcWriteOptions& options, std::shared_ptr rpc,
- int64_t write_size_limit_bytes,
- std::shared_ptr> writer,
- std::unique_ptr* out);
+ server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context);
+ if (!server_status_.ok()) {
+ server_status_ = Status::FromDetailAndArgs(
+ server_status_.code(), server_status_.detail(), server_status_.message(),
+ ". gRPC client debug context: ", rpc_->context.debug_error_string());
+ }
+ finished_ = true;
- Status CheckStarted() {
- if (!batch_writer_) {
- return Status::Invalid("Writer not initialized. Call Begin() with a schema.");
+ return server_status_;
+ }
+ void TryCancel() override { rpc_->context.TryCancel(); }
+
+ std::shared_ptr rpc_;
+ std::shared_ptr stream_;
+ std::shared_ptr memory_manager_;
+ bool finished_;
+ Status server_status_;
+};
+
+/// \brief A ClientDataStream implementation for gRPC that manages a
+/// mutex to protect from concurrent reads/writes, and drains the
+/// read side on finish.
+template
+class WritableDataStream : public FinishableDataStream {
+ public:
+ using Base = FinishableDataStream;
+ WritableDataStream(std::shared_ptr rpc, std::shared_ptr stream,
+ std::shared_ptr memory_manager)
+ : Base(std::move(rpc), std::move(stream), std::move(memory_manager)),
+ read_mutex_(),
+ finish_mutex_(),
+ done_writing_(false) {}
+
+ Status WritesDone() override {
+ // This is only used by the writer side of a stream, so it need
+ // not be protected with a lock.
+ if (done_writing_) {
+ return Status::OK();
+ }
+ done_writing_ = true;
+ if (!stream_->WritesDone()) {
+ // Error happened, try to close the stream to get more detailed info
+ return internal::ClientDataStream::Finish(MakeFlightError(
+ FlightStatusCode::Internal, "Could not flush pending record batches"));
}
return Status::OK();
}
- Status Begin(const std::shared_ptr& schema,
- const ipc::IpcWriteOptions& options) override {
- if (batch_writer_) {
- return Status::Invalid("This writer has already been started.");
+ Status Finish() override {
+ // This may be used concurrently by reader/writer side of a
+ // stream, so it needs to be protected.
+ std::lock_guard guard(finish_mutex_);
+
+ // Now that we're shared between a reader and writer, we need to
+ // protect ourselves from being called while there's an
+ // outstanding read.
+ std::unique_lock read_guard(read_mutex_, std::try_to_lock);
+ if (!read_guard.owns_lock()) {
+ return MakeFlightError(FlightStatusCode::Internal,
+ "Cannot close stream with pending read operation.");
}
- std::unique_ptr payload_writer(
- new DoPutPayloadWriter(
- descriptor_, std::move(rpc_), write_size_limit_bytes_, writer_, this));
- // XXX: this does not actually write the message to the stream.
- // See Close().
- ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter(
- std::move(payload_writer), schema, options));
- return Status::OK();
+
+ // Try to flush pending writes. Don't use our WritesDone() to
+ // avoid recursion.
+ bool finished_writes = done_writing_ || stream_->WritesDone();
+ done_writing_ = true;
+
+ Status st = Base::Finish();
+ if (!finished_writes) {
+ return Status::FromDetailAndArgs(
+ st.code(), st.detail(), st.message(),
+ ". Additionally, could not finish writing record batches before closing");
+ }
+ return st;
}
- Status Begin(const std::shared_ptr& schema) override {
- return Begin(schema, options_);
+ using Base::stream_;
+ std::mutex read_mutex_;
+ std::mutex finish_mutex_;
+ bool done_writing_;
+};
+
+class GrpcClientGetStream
+ : public FinishableDataStream,
+ internal::FlightData> {
+ public:
+ using FinishableDataStream::FinishableDataStream;
+
+ bool ReadData(internal::FlightData* data) override {
+ bool success = internal::ReadPayload(stream_.get(), data);
+ if (ARROW_PREDICT_FALSE(!success)) return false;
+ if (data->body &&
+ ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) {
+ auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body);
+ if (!status.ok()) {
+ server_status_ = std::move(status);
+ return false;
+ }
+ }
+ return true;
}
+ Status WritesDone() override { return Status::NotImplemented("NYI"); }
+};
- Status WriteRecordBatch(const RecordBatch& batch) override {
- RETURN_NOT_OK(CheckStarted());
- return WriteWithMetadata(batch, nullptr);
+class GrpcClientPutStream
+ : public WritableDataStream,
+ pb::PutResult> {
+ public:
+ using Stream = grpc::ClientReaderWriter;
+ GrpcClientPutStream(std::shared_ptr rpc, std::shared_ptr stream,
+ std::shared_ptr memory_manager)
+ : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) {
}
- Status WriteMetadata(std::shared_ptr app_metadata) override {
- FlightPayload payload{};
- payload.app_metadata = app_metadata;
- auto status = internal::WritePayload(payload, writer_->stream().get());
+ bool ReadPutMetadata(std::shared_ptr* out) override {
+ std::lock_guard guard(read_mutex_);
+ pb::PutResult message;
+ if (stream_->Read(&message)) {
+ *out = Buffer::FromString(std::move(*message.mutable_app_metadata()));
+ } else {
+ // Stream finished
+ *out = nullptr;
+ }
+ return true;
+ }
+ Status WriteData(const FlightPayload& payload) override {
+ auto status = internal::WritePayload(payload, this->stream_.get());
if (status.IsIOError()) {
- return writer_->Finish(MakeFlightError(FlightStatusCode::Internal,
- "Could not write metadata to stream"));
+ return internal::ClientDataStream::Finish(MakeFlightError(
+ FlightStatusCode::Internal, "Could not write record batch to stream"));
}
return status;
}
+};
- Status WriteWithMetadata(const RecordBatch& batch,
- std::shared_ptr app_metadata) override {
- RETURN_NOT_OK(CheckStarted());
- app_metadata_ = app_metadata;
- return batch_writer_->WriteRecordBatch(batch);
+class GrpcClientExchangeStream
+ : public WritableDataStream,
+ internal::FlightData> {
+ public:
+ using Stream = grpc::ClientReaderWriter;
+ GrpcClientExchangeStream(std::shared_ptr rpc, std::shared_ptr stream,
+ std::shared_ptr memory_manager)
+ : WritableDataStream(std::move(rpc), std::move(stream), std::move(memory_manager)) {
}
- Status DoneWriting() override {
- // Do not CheckStarted - DoneWriting applies to data and metadata
- if (batch_writer_) {
- // Close the writer if we have one; this will force it to flush any
- // remaining data, before we close the write side of the stream.
- writer_closed_ = true;
- Status st = batch_writer_->Close();
- if (!st.ok()) {
- return writer_->Finish(std::move(st));
+ bool ReadData(internal::FlightData* data) override {
+ std::lock_guard guard(read_mutex_);
+ bool success = internal::ReadPayload(stream_.get(), data);
+ if (ARROW_PREDICT_FALSE(!success)) return false;
+ if (data->body &&
+ ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) {
+ auto status = Buffer::ViewOrCopy(data->body, memory_manager_).Value(&data->body);
+ if (!status.ok()) {
+ server_status_ = std::move(status);
+ return false;
}
}
- return writer_->DoneWriting();
+ return true;
}
-
- Status Close() override {
- // Do not CheckStarted - Close applies to data and metadata
- if (batch_writer_ && !writer_closed_) {
- // This is important! Close() calls
- // IpcPayloadWriter::CheckStarted() which will force the initial
- // schema message to be written to the stream. This is required
- // to unstick the server, else the client and the server end up
- // waiting for each other. This happens if the client never
- // wrote anything before calling Close().
- writer_closed_ = true;
- return writer_->Finish(batch_writer_->Close());
+ Status WriteData(const FlightPayload& payload) override {
+ auto status = internal::WritePayload(payload, this->stream_.get());
+ if (status.IsIOError()) {
+ return internal::ClientDataStream::Finish(MakeFlightError(
+ FlightStatusCode::Internal, "Could not write record batch to stream"));
}
- return writer_->Finish(Status::OK());
+ return status;
}
+};
- ipc::WriteStats stats() const override {
- ARROW_CHECK_NE(batch_writer_, nullptr);
- return batch_writer_->stats();
+class ClientMetadataReader : public FlightMetadataReader {
+ public:
+ explicit ClientMetadataReader(std::shared_ptr stream)
+ : stream_(std::move(stream)) {}
+
+ Status ReadMetadata(std::shared_ptr* out) override {
+ if (!stream_->ReadPutMetadata(out)) {
+ return stream_->Finish(Status::OK());
+ }
+ return Status::OK();
}
private:
- friend class DoPutPayloadWriter;
- std::shared_ptr app_metadata_;
- std::unique_ptr batch_writer_;
- std::shared_ptr> writer_;
-
- // Fields used to lazy-initialize the IpcPayloadWriter. They're
- // invalid once Begin() is called.
- std::shared_ptr rpc_;
- int64_t write_size_limit_bytes_;
- ipc::IpcWriteOptions options_;
- FlightDescriptor descriptor_;
- bool writer_closed_;
+ std::shared_ptr stream_;
};
-/// A IpcPayloadWriter implementation that writes to a gRPC stream of
-/// FlightData messages.
-template
-class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
+/// \brief An IpcPayloadWriter for any ClientDataStream.
+///
+/// To support app_metadata and reuse the existing IPC infrastructure,
+/// this takes a pointer to a buffer to be combined with the IPC
+/// payload when writing a Flight payload.
+class ClientPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
public:
- using GrpcStream = grpc::ClientReaderWriter;
-
- DoPutPayloadWriter(
- const FlightDescriptor& descriptor, std::shared_ptr rpc,
- int64_t write_size_limit_bytes,
- std::shared_ptr> writer,
- GrpcStreamWriter* stream_writer)
- : descriptor_(descriptor),
- rpc_(rpc),
+ explicit ClientPutPayloadWriter(std::shared_ptr stream,
+ FlightDescriptor descriptor,
+ int64_t write_size_limit_bytes,
+ std::shared_ptr* app_metadata)
+ : descriptor_(std::move(descriptor)),
write_size_limit_bytes_(write_size_limit_bytes),
- writer_(std::move(writer)),
- first_payload_(true),
- stream_writer_(stream_writer) {}
-
- ~DoPutPayloadWriter() override = default;
+ stream_(std::move(stream)),
+ app_metadata_(app_metadata),
+ first_payload_(true) {}
Status Start() override { return Status::OK(); }
-
Status WritePayload(const ipc::IpcPayload& ipc_payload) override {
FlightPayload payload;
payload.ipc_message = ipc_payload;
@@ -800,9 +690,8 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
// Write the descriptor to begin with
RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor));
first_payload_ = false;
- } else if (ipc_payload.type == ipc::MessageType::RECORD_BATCH &&
- stream_writer_->app_metadata_) {
- payload.app_metadata = std::move(stream_writer_->app_metadata_);
+ } else if (ipc_payload.type == ipc::MessageType::RECORD_BATCH && *app_metadata_) {
+ payload.app_metadata = std::move(*app_metadata_);
}
if (write_size_limit_bytes_ > 0) {
@@ -821,85 +710,135 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
std::make_shared(write_size_limit_bytes_, size));
}
}
-
- auto status = internal::WritePayload(payload, writer_->stream().get());
- if (status.IsIOError()) {
- return writer_->Finish(MakeFlightError(FlightStatusCode::Internal,
- "Could not write record batch to stream"));
- }
- return status;
+ return stream_->WriteData(payload);
}
-
Status Close() override {
- // Closing is handled one layer up in GrpcStreamWriter::Close
+ // Closing is handled one layer up in ClientStreamWriter::Close
return Status::OK();
}
- protected:
+ private:
const FlightDescriptor descriptor_;
- std::shared_ptr rpc_;
- int64_t write_size_limit_bytes_;
- std::shared_ptr> writer_;
+ const int64_t write_size_limit_bytes_;
+ std::shared_ptr stream_;
+ std::shared_ptr* app_metadata_;
bool first_payload_;
- GrpcStreamWriter* stream_writer_;
};
-template
-Status GrpcStreamWriter::Open(
- const FlightDescriptor& descriptor,
- std::shared_ptr schema, // this schema is nullable
- const ipc::IpcWriteOptions& options, std::shared_ptr rpc,
- int64_t write_size_limit_bytes,
- std::shared_ptr> writer,
- std::unique_ptr* out) {
- std::unique_ptr> instance(
- new GrpcStreamWriter(
- descriptor, std::move(rpc), write_size_limit_bytes, options, writer));
- if (schema) {
- // The schema was provided (DoPut). Eagerly write the schema and
- // descriptor together as the first message.
- RETURN_NOT_OK(instance->Begin(schema, options));
- } else {
- // The schema was not provided (DoExchange). Eagerly write just
- // the descriptor as the first message. Note that if the client
- // calls Begin() to send data, we'll send a redundant descriptor.
- FlightPayload payload{};
- RETURN_NOT_OK(internal::ToPayload(descriptor, &payload.descriptor));
- auto status = internal::WritePayload(payload, instance->writer_->stream().get());
+class ClientStreamWriter : public FlightStreamWriter {
+ public:
+ explicit ClientStreamWriter(std::shared_ptr stream,
+ const ipc::IpcWriteOptions& options,
+ int64_t write_size_limit_bytes, FlightDescriptor descriptor)
+ : stream_(std::move(stream)),
+ batch_writer_(nullptr),
+ app_metadata_(nullptr),
+ writer_closed_(false),
+ write_options_(options),
+ write_size_limit_bytes_(write_size_limit_bytes),
+ descriptor_(std::move(descriptor)) {}
+
+ Status Begin(const std::shared_ptr& schema,
+ const ipc::IpcWriteOptions& options) override {
+ if (batch_writer_) {
+ return Status::Invalid("This writer has already been started.");
+ }
+ std::unique_ptr payload_writer(
+ new ClientPutPayloadWriter(stream_, std::move(descriptor_),
+ write_size_limit_bytes_, &app_metadata_));
+ // XXX: this does not actually write the message to the stream.
+ // See Close().
+ ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter(
+ std::move(payload_writer), schema, options));
+ return Status::OK();
+ }
+
+ Status Begin(const std::shared_ptr& schema) override {
+ return Begin(schema, write_options_);
+ }
+
+ // Overload used by FlightClient::DoExchange
+ Status Begin() {
+ FlightPayload payload;
+ RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor));
+ RETURN_NOT_OK(stream_->WriteData(payload));
+ return Status::OK();
+ }
+
+ Status WriteRecordBatch(const RecordBatch& batch) override {
+ RETURN_NOT_OK(CheckStarted());
+ return WriteWithMetadata(batch, nullptr);
+ }
+
+ Status WriteMetadata(std::shared_ptr app_metadata) override {
+ FlightPayload payload;
+ payload.app_metadata = app_metadata;
+ auto status = stream_->WriteData(payload);
if (status.IsIOError()) {
- return writer->Finish(MakeFlightError(FlightStatusCode::Internal,
- "Could not write descriptor to stream"));
+ return stream_->Finish(MakeFlightError(FlightStatusCode::Internal,
+ "Could not write metadata to stream"));
}
- RETURN_NOT_OK(status);
+ return status;
}
- *out = std::move(instance);
- return Status::OK();
-}
-FlightMetadataReader::~FlightMetadataReader() = default;
+ Status WriteWithMetadata(const RecordBatch& batch,
+ std::shared_ptr app_metadata) override {
+ RETURN_NOT_OK(CheckStarted());
+ app_metadata_ = app_metadata;
+ return batch_writer_->WriteRecordBatch(batch);
+ }
-class GrpcMetadataReader : public FlightMetadataReader {
- public:
- explicit GrpcMetadataReader(
- std::shared_ptr> reader,
- std::shared_ptr read_mutex)
- : reader_(reader), read_mutex_(read_mutex) {}
+ Status DoneWriting() override {
+ // Do not CheckStarted - DoneWriting applies to data and metadata
+ if (batch_writer_) {
+ // Close the writer if we have one; this will force it to flush any
+ // remaining data, before we close the write side of the stream.
+ writer_closed_ = true;
+ Status st = batch_writer_->Close();
+ if (!st.ok()) {
+ return stream_->Finish(std::move(st));
+ }
+ }
+ return stream_->WritesDone();
+ }
- Status ReadMetadata(std::shared_ptr* out) override {
- std::lock_guard guard(*read_mutex_);
- pb::PutResult message;
- if (reader_->Read(&message)) {
- *out = Buffer::FromString(std::move(*message.mutable_app_metadata()));
- } else {
- // Stream finished
- *out = nullptr;
+ Status Close() override {
+ // Do not CheckStarted - Close applies to data and metadata
+ if (batch_writer_ && !writer_closed_) {
+ // This is important! Close() calls
+ // IpcPayloadWriter::CheckStarted() which will force the initial
+ // schema message to be written to the stream. This is required
+ // to unstick the server, else the client and the server end up
+ // waiting for each other. This happens if the client never
+ // wrote anything before calling Close().
+ writer_closed_ = true;
+ return stream_->Finish(batch_writer_->Close());
}
- return Status::OK();
+ return stream_->Finish(Status::OK());
+ }
+
+ ipc::WriteStats stats() const override {
+ ARROW_CHECK_NE(batch_writer_, nullptr);
+ return batch_writer_->stats();
}
private:
- std::shared_ptr> reader_;
- std::shared_ptr read_mutex_;
+ Status CheckStarted() {
+ if (!batch_writer_) {
+ return Status::Invalid("Writer not initialized. Call Begin() with a schema.");
+ }
+ return Status::OK();
+ }
+
+ std::shared_ptr stream_;
+ std::unique_ptr batch_writer_;
+ std::shared_ptr app_metadata_;
+ bool writer_closed_;
+
+ // Temporary state to construct the IPC payload writer
+ ipc::IpcWriteOptions write_options_;
+ int64_t write_size_limit_bytes_;
+ FlightDescriptor descriptor_;
};
namespace {
@@ -927,16 +866,16 @@ constexpr char kDummyRootCert[] =
"-----END CERTIFICATE-----\n";
#endif
} // namespace
-class FlightClient::FlightClientImpl {
+class GrpcClientImpl : public internal::ClientTransportImpl {
public:
- Status Connect(const Location& location, const FlightClientOptions& options) {
+ Status Init(const FlightClientOptions& options, const Location& location,
+ const arrow::internal::Uri& uri) override {
const std::string& scheme = location.scheme();
std::stringstream grpc_uri;
std::shared_ptr creds;
if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
- grpc_uri << arrow::internal::UriEncodeHost(location.uri_->host()) << ':'
- << location.uri_->port_text();
+ grpc_uri << arrow::internal::UriEncodeHost(uri.host()) << ':' << uri.port_text();
if (scheme == kSchemeGrpcTls) {
if (options.disable_server_verification) {
@@ -1031,10 +970,11 @@ class FlightClient::FlightClientImpl {
creds = grpc::InsecureChannelCredentials();
}
} else if (scheme == kSchemeGrpcUnix) {
- grpc_uri << "unix://" << location.uri_->path();
+ grpc_uri << "unix://" << uri.path();
creds = grpc::InsecureChannelCredentials();
} else {
- return Status::NotImplemented("Flight scheme " + scheme + " is not supported.");
+ return Status::NotImplemented("Flight scheme ", scheme,
+ " is not supported by the gRPC transport");
}
grpc::ChannelArguments args;
@@ -1075,13 +1015,13 @@ class FlightClient::FlightClientImpl {
stub_ = pb::FlightService::NewStub(
grpc::experimental::CreateCustomChannelWithInterceptors(
grpc_uri.str(), creds, args, std::move(interceptors)));
-
- write_size_limit_bytes_ = options.write_size_limit_bytes;
return Status::OK();
}
+ Status Close() override { return Status::OK(); }
+
Status Authenticate(const FlightCallOptions& options,
- std::unique_ptr auth_handler) {
+ std::unique_ptr auth_handler) override {
auth_handler_ = std::move(auth_handler);
ClientRpc rpc(options);
std::shared_ptr>
@@ -1101,7 +1041,7 @@ class FlightClient::FlightClientImpl {
arrow::Result> AuthenticateBasicToken(
const FlightCallOptions& options, const std::string& username,
- const std::string& password) {
+ const std::string& password) override {
// Add basic auth headers to outgoing headers.
ClientRpc rpc(options);
internal::AddBasicAuthHeaders(&rpc.context, username, password);
@@ -1124,7 +1064,7 @@ class FlightClient::FlightClientImpl {
}
Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
- std::unique_ptr* listing) {
+ std::unique_ptr* listing) override {
pb::Criteria pb_criteria;
RETURN_NOT_OK(internal::ToProto(criteria, &pb_criteria));
@@ -1148,7 +1088,7 @@ class FlightClient::FlightClientImpl {
}
Status DoAction(const FlightCallOptions& options, const Action& action,
- std::unique_ptr* results) {
+ std::unique_ptr* results) override {
pb::Action pb_action;
RETURN_NOT_OK(internal::ToProto(action, &pb_action));
@@ -1173,7 +1113,8 @@ class FlightClient::FlightClientImpl {
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}
- Status ListActions(const FlightCallOptions& options, std::vector* types) {
+ Status ListActions(const FlightCallOptions& options,
+ std::vector* types) override {
pb::Empty empty;
ClientRpc rpc(options);
@@ -1194,7 +1135,7 @@ class FlightClient::FlightClientImpl {
Status GetFlightInfo(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
- std::unique_ptr* info) {
+ std::unique_ptr* info) override {
pb::FlightDescriptor pb_descriptor;
pb::FlightInfo pb_response;
@@ -1213,7 +1154,7 @@ class FlightClient::FlightClientImpl {
}
Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor,
- std::unique_ptr* schema_result) {
+ std::unique_ptr* schema_result) override {
pb::FlightDescriptor pb_descriptor;
pb::SchemaResult pb_response;
@@ -1232,8 +1173,7 @@ class FlightClient::FlightClientImpl {
}
Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr* out) {
- using StreamReader = GrpcStreamReader>;
+ std::unique_ptr* out) override {
pb::Ticket pb_ticket;
internal::ToProto(ticket, &pb_ticket);
@@ -1241,62 +1181,33 @@ class FlightClient::FlightClientImpl {
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::shared_ptr> stream =
stub_->DoGet(&rpc->context, pb_ticket);
- auto finishable_stream = std::make_shared<
- FinishableStream, internal::FlightData>>(
- rpc, stream);
- *out = std::unique_ptr(
- new StreamReader(rpc, options.memory_manager, nullptr, options.read_options,
- options.stop_token, finishable_stream));
- // Eagerly read the schema
- return static_cast(out->get())->EnsureDataStarted();
+ *out = std::unique_ptr(new GrpcClientGetStream(
+ std::move(rpc), std::move(stream), options.memory_manager));
+ return Status::OK();
}
- Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
- const std::shared_ptr& schema,
- std::unique_ptr* out,
- std::unique_ptr* reader) {
+ Status DoPut(const FlightCallOptions& options,
+ std::unique_ptr* out) override {
using GrpcStream = grpc::ClientReaderWriter;
- using StreamWriter = GrpcStreamWriter;
auto rpc = std::make_shared(options);
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::shared_ptr stream = stub_->DoPut(&rpc->context);
- // The writer drains the reader on close to avoid hanging inside
- // gRPC. Concurrent reads are unsafe, so a mutex protects this operation.
- std::shared_ptr read_mutex = std::make_shared();
- auto finishable_stream =
- std::make_shared>(
- rpc, read_mutex, stream);
- *reader =
- std::unique_ptr(new GrpcMetadataReader(stream, read_mutex));
- return StreamWriter::Open(descriptor, schema, options.write_options, rpc,
- write_size_limit_bytes_, finishable_stream, out);
+ *out = std::unique_ptr(new GrpcClientPutStream(
+ std::move(rpc), std::move(stream), options.memory_manager));
+ return Status::OK();
}
- Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor,
- std::unique_ptr* writer,
- std::unique_ptr* reader) {
+ Status DoExchange(const FlightCallOptions& options,
+ std::unique_ptr* out) override {
using GrpcStream = grpc::ClientReaderWriter;
- using StreamReader = GrpcStreamReader;
- using StreamWriter = GrpcStreamWriter;
auto rpc = std::make_shared(options);
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
- std::shared_ptr> stream =
- stub_->DoExchange(&rpc->context);
- // The writer drains the reader on close to avoid hanging inside
- // gRPC. Concurrent reads are unsafe, so a mutex protects this operation.
- std::shared_ptr read_mutex = std::make_shared();
- auto finishable_stream =
- std::make_shared>(
- rpc, read_mutex, stream);
- *reader = std::unique_ptr(
- new StreamReader(rpc, options.memory_manager, read_mutex, options.read_options,
- options.stop_token, finishable_stream));
- // Do not eagerly read the schema. There may be metadata messages
- // before any data is sent, or data may not be sent at all.
- return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc,
- write_size_limit_bytes_, finishable_stream, writer);
+ std::shared_ptr stream = stub_->DoExchange(&rpc->context);
+ *out = std::unique_ptr(new GrpcClientExchangeStream(
+ std::move(rpc), std::move(stream), options.memory_manager));
+ return Status::OK();
}
private:
@@ -1312,10 +1223,9 @@ class FlightClient::FlightClientImpl {
GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig>
noop_auth_check_;
#endif
- int64_t write_size_limit_bytes_;
};
-FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); }
+FlightClient::FlightClient() {}
FlightClient::~FlightClient() {
auto st = Close();
@@ -1333,7 +1243,16 @@ Status FlightClient::Connect(const Location& location,
Status FlightClient::Connect(const Location& location, const FlightClientOptions& options,
std::unique_ptr* client) {
client->reset(new FlightClient);
- return (*client)->impl_->Connect(location, options);
+ (*client)->write_size_limit_bytes_ = options.write_size_limit_bytes;
+ const auto scheme = location.scheme();
+ if (util::string_view(scheme).starts_with("grpc")) {
+ (*client)->impl_.reset(new GrpcClientImpl);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ (*client)->impl_,
+ internal::GetDefaultTransportImplRegistry()->MakeClientImpl(scheme));
+ }
+ return (*client)->impl_->Init(options, location, *location.uri_);
}
Status FlightClient::Authenticate(const FlightCallOptions& options,
@@ -1390,7 +1309,12 @@ Status FlightClient::ListFlights(const FlightCallOptions& options,
Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr* stream) {
RETURN_NOT_OK(CheckOpen());
- return impl_->DoGet(options, ticket, stream);
+ std::unique_ptr remote_stream;
+ RETURN_NOT_OK(impl_->DoGet(options, ticket, &remote_stream));
+ *stream = std::unique_ptr(new ClientStreamReader(
+ std::move(remote_stream), options.read_options, options.stop_token));
+ // Eagerly read the schema
+ return static_cast(stream->get())->EnsureDataStarted();
}
Status FlightClient::DoPut(const FlightCallOptions& options,
@@ -1399,7 +1323,16 @@ Status FlightClient::DoPut(const FlightCallOptions& options,
std::unique_ptr* stream,
std::unique_ptr* reader) {
RETURN_NOT_OK(CheckOpen());
- return impl_->DoPut(options, descriptor, schema, stream, reader);
+ std::unique_ptr remote_stream;
+ RETURN_NOT_OK(impl_->DoPut(options, &remote_stream));
+ std::shared_ptr shared_stream = std::move(remote_stream);
+ *reader =
+ std::unique_ptr(new ClientMetadataReader(shared_stream));
+ *stream = std::unique_ptr(
+ new ClientStreamWriter(std::move(shared_stream), options.write_options,
+ write_size_limit_bytes_, descriptor));
+ RETURN_NOT_OK((*stream)->Begin(schema, options.write_options));
+ return Status::OK();
}
Status FlightClient::DoExchange(const FlightCallOptions& options,
@@ -1407,7 +1340,17 @@ Status FlightClient::DoExchange(const FlightCallOptions& options,
std::unique_ptr* writer,
std::unique_ptr* reader) {
RETURN_NOT_OK(CheckOpen());
- return impl_->DoExchange(options, descriptor, writer, reader);
+ std::unique_ptr remote_stream;
+ RETURN_NOT_OK(impl_->DoExchange(options, &remote_stream));
+ std::shared_ptr shared_stream = std::move(remote_stream);
+ *reader = std::unique_ptr(
+ new ClientStreamReader(shared_stream, options.read_options, options.stop_token));
+ auto stream_writer = std::unique_ptr(
+ new ClientStreamWriter(std::move(shared_stream), options.write_options,
+ write_size_limit_bytes_, descriptor));
+ RETURN_NOT_OK(stream_writer->Begin());
+ *writer = std::move(stream_writer);
+ return Status::OK();
}
Status FlightClient::Close() {
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 15c67051931..43a1e9be91a 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -176,6 +176,11 @@ class ARROW_FLIGHT_EXPORT FlightMetadataReader {
virtual Status ReadMetadata(std::shared_ptr* out) = 0;
};
+// Forward declaration
+namespace internal {
+class ClientTransportImpl;
+}
+
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_FLIGHT_EXPORT FlightClient {
@@ -336,8 +341,8 @@ class ARROW_FLIGHT_EXPORT FlightClient {
private:
FlightClient();
Status CheckOpen() const;
- class FlightClientImpl;
- std::unique_ptr impl_;
+ std::unique_ptr impl_;
+ int64_t write_size_limit_bytes_;
};
} // namespace flight
diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h
index 5f7d0cc487c..11f5e16f4b1 100644
--- a/cpp/src/arrow/flight/serialization_internal.h
+++ b/cpp/src/arrow/flight/serialization_internal.h
@@ -23,6 +23,7 @@
#include
#include "arrow/flight/internal.h"
+#include "arrow/flight/transport_impl.h"
#include "arrow/flight/types.h"
#include "arrow/ipc/message.h"
#include "arrow/result.h"
@@ -34,25 +35,6 @@ class Buffer;
namespace flight {
namespace internal {
-/// Internal, not user-visible type used for memory-efficient reads from gRPC
-/// stream
-struct FlightData {
- /// Used only for puts, may be null
- std::unique_ptr descriptor;
-
- /// Non-length-prefixed Message header as described in format/Message.fbs
- std::shared_ptr metadata;
-
- /// Application-defined metadata
- std::shared_ptr app_metadata;
-
- /// Message body
- std::shared_ptr body;
-
- /// Open IPC message from the metadata and body
- ::arrow::Result> OpenMessage();
-};
-
/// Write Flight message on gRPC stream with zero-copy optimizations.
// Returns Invalid if the payload is ill-formed
// Returns IOError if gRPC did not write the message (note this is not
@@ -87,10 +69,9 @@ bool ReadPayload(grpc::ClientReaderWriter* reader
// The Flight reader can then peek at the message to determine whether
// it has application metadata or not, and pass the message to
// RecordBatchStreamReader as appropriate.
-template
class PeekableFlightDataReader {
public:
- explicit PeekableFlightDataReader(ReaderPtr stream)
+ explicit PeekableFlightDataReader(TransportDataStream* stream)
: stream_(stream), peek_(), finished_(false), valid_(false) {}
void Peek(internal::FlightData** out) {
@@ -132,7 +113,7 @@ class PeekableFlightDataReader {
return valid_;
}
- if (!internal::ReadPayload(&*stream_, &peek_)) {
+ if (!stream_->ReadData(&peek_)) {
finished_ = true;
valid_ = false;
} else {
@@ -141,7 +122,7 @@ class PeekableFlightDataReader {
return valid_;
}
- ReaderPtr stream_;
+ internal::TransportDataStream* stream_;
internal::FlightData peek_;
bool finished_;
bool valid_;
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 1c613b3c7c0..011cba093d6 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -61,6 +61,7 @@
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/transport_impl.h"
#include "arrow/flight/types.h"
using FlightService = arrow::flight::protocol::FlightService;
@@ -97,163 +98,9 @@ namespace pb = arrow::flight::protocol;
} \
} while (false)
-namespace {
-
-// A MessageReader implementation that reads from a gRPC ServerReader.
-// Templated to be generic over DoPut/DoExchange.
-template
-class FlightIpcMessageReader : public ipc::MessageReader {
- public:
- explicit FlightIpcMessageReader(
- std::shared_ptr> peekable_reader,
- std::shared_ptr memory_manager,
- std::shared_ptr* app_metadata)
- : peekable_reader_(peekable_reader),
- memory_manager_(std::move(memory_manager)),
- app_metadata_(app_metadata) {}
-
- ::arrow::Result> ReadNextMessage() override {
- if (stream_finished_) {
- return nullptr;
- }
- internal::FlightData* data;
- peekable_reader_->Next(&data);
- if (!data) {
- stream_finished_ = true;
- if (first_message_) {
- return Status::Invalid(
- "Client provided malformed message or did not provide message");
- }
- return nullptr;
- }
- if (ARROW_PREDICT_FALSE(!memory_manager_->is_cpu())) {
- ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_));
- }
- *app_metadata_ = std::move(data->app_metadata);
- return data->OpenMessage();
- }
-
- protected:
- std::shared_ptr> peekable_reader_;
- std::shared_ptr memory_manager_;
- // A reference to FlightMessageReaderImpl.app_metadata_. That class
- // can't access the app metadata because when it Peek()s the stream,
- // it may be looking at a dictionary batch, not the record
- // batch. Updating it here ensures the reader is always updated with
- // the last metadata message read.
- std::shared_ptr* app_metadata_;
- bool first_message_ = true;
- bool stream_finished_ = false;
-};
-
-template
-class FlightMessageReaderImpl : public FlightMessageReader {
- public:
- using GrpcStream = grpc::ServerReaderWriter;
-
- explicit FlightMessageReaderImpl(GrpcStream* reader,
- std::shared_ptr memory_manager)
- : reader_(reader),
- memory_manager_(std::move(memory_manager)),
- peekable_reader_(new internal::PeekableFlightDataReader(reader)) {}
-
- Status Init() {
- // Peek the first message to get the descriptor.
- internal::FlightData* data;
- peekable_reader_->Peek(&data);
- if (!data) {
- return Status::IOError("Stream finished before first message sent");
- }
- if (!data->descriptor) {
- return Status::IOError("Descriptor missing on first message");
- }
- descriptor_ = *data->descriptor.get(); // Copy
- // If there's a schema (=DoPut), also Open().
- if (data->metadata) {
- return EnsureDataStarted();
- }
- peekable_reader_->Next(&data);
- return Status::OK();
- }
-
- const FlightDescriptor& descriptor() const override { return descriptor_; }
-
- arrow::Result> GetSchema() override {
- RETURN_NOT_OK(EnsureDataStarted());
- return batch_reader_->schema();
- }
-
- Status Next(FlightStreamChunk* out) override {
- internal::FlightData* data;
- peekable_reader_->Peek(&data);
- if (!data) {
- out->app_metadata = nullptr;
- out->data = nullptr;
- return Status::OK();
- }
-
- if (!data->metadata) {
- // Metadata-only (data->metadata is the IPC header)
- out->app_metadata = data->app_metadata;
- out->data = nullptr;
- peekable_reader_->Next(&data);
- return Status::OK();
- }
-
- if (!batch_reader_) {
- RETURN_NOT_OK(EnsureDataStarted());
- // re-peek here since EnsureDataStarted() advances the stream
- return Next(out);
- }
- RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
- out->app_metadata = std::move(app_metadata_);
- return Status::OK();
- }
-
- private:
- /// Ensure we are set up to read data.
- Status EnsureDataStarted() {
- if (!batch_reader_) {
- // peek() until we find the first data message; discard metadata
- if (!peekable_reader_->SkipToData()) {
- return Status::IOError("Client never sent a data message");
- }
- auto message_reader =
- std::unique_ptr(new FlightIpcMessageReader(
- peekable_reader_, memory_manager_, &app_metadata_));
- ARROW_ASSIGN_OR_RAISE(
- batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader)));
- }
- return Status::OK();
- }
-
- FlightDescriptor descriptor_;
- GrpcStream* reader_;
- std::shared_ptr memory_manager_;
- std::shared_ptr> peekable_reader_;
- std::shared_ptr batch_reader_;
- std::shared_ptr app_metadata_;
-};
-
-class GrpcMetadataWriter : public FlightMetadataWriter {
- public:
- explicit GrpcMetadataWriter(
- grpc::ServerReaderWriter* writer)
- : writer_(writer) {}
-
- Status WriteMetadata(const Buffer& buffer) override {
- pb::PutResult message{};
- message.set_app_metadata(buffer.data(), buffer.size());
- if (writer_->Write(message)) {
- return Status::OK();
- }
- return Status::IOError("Unknown error writing metadata.");
- }
-
- private:
- grpc::ServerReaderWriter* writer_;
-};
+FlightMetadataWriter::~FlightMetadataWriter() = default;
+namespace {
class GrpcServerAuthReader : public ServerAuthReader {
public:
explicit GrpcServerAuthReader(
@@ -292,100 +139,7 @@ class GrpcServerAuthSender : public ServerAuthSender {
grpc::ServerReaderWriter* stream_;
};
-/// The implementation of the write side of a bidirectional FlightData
-/// stream for DoExchange.
-class DoExchangeMessageWriter : public FlightMessageWriter {
- public:
- explicit DoExchangeMessageWriter(
- grpc::ServerReaderWriter* stream)
- : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {}
-
- Status Begin(const std::shared_ptr& schema,
- const ipc::IpcWriteOptions& options) override {
- if (started_) {
- return Status::Invalid("This writer has already been started.");
- }
- started_ = true;
- ipc_options_ = options;
-
- RETURN_NOT_OK(mapper_.AddSchemaFields(*schema));
- FlightPayload schema_payload;
- RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_,
- &schema_payload.ipc_message));
- return WritePayload(schema_payload);
- }
-
- Status WriteRecordBatch(const RecordBatch& batch) override {
- return WriteWithMetadata(batch, nullptr);
- }
-
- Status WriteMetadata(std::shared_ptr app_metadata) override {
- FlightPayload payload{};
- payload.app_metadata = app_metadata;
- return WritePayload(payload);
- }
-
- Status WriteWithMetadata(const RecordBatch& batch,
- std::shared_ptr app_metadata) override {
- RETURN_NOT_OK(CheckStarted());
- RETURN_NOT_OK(EnsureDictionariesWritten(batch));
- FlightPayload payload{};
- if (app_metadata) {
- payload.app_metadata = app_metadata;
- }
- RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message));
- RETURN_NOT_OK(WritePayload(payload));
- ++stats_.num_record_batches;
- return Status::OK();
- }
-
- Status Close() override {
- // It's fine to Close() without writing data
- return Status::OK();
- }
-
- ipc::WriteStats stats() const override { return stats_; }
-
- private:
- Status WritePayload(const FlightPayload& payload) {
- RETURN_NOT_OK(internal::WritePayload(payload, stream_));
- ++stats_.num_messages;
- return Status::OK();
- }
-
- Status CheckStarted() {
- if (!started_) {
- return Status::Invalid("This writer is not started. Call Begin() with a schema");
- }
- return Status::OK();
- }
-
- Status EnsureDictionariesWritten(const RecordBatch& batch) {
- if (dictionaries_written_) {
- return Status::OK();
- }
- dictionaries_written_ = true;
- ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
- ipc::CollectDictionaries(batch, mapper_));
- for (const auto& pair : dictionaries) {
- FlightPayload payload{};
- RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_,
- &payload.ipc_message));
- RETURN_NOT_OK(WritePayload(payload));
- ++stats_.num_dictionary_batches;
- }
- return Status::OK();
- }
-
- grpc::ServerReaderWriter* stream_;
- ::arrow::ipc::IpcWriteOptions ipc_options_;
- ipc::DictionaryFieldMapper mapper_;
- ipc::WriteStats stats_;
- bool started_ = false;
- bool dictionaries_written_ = false;
-};
-
-class FlightServiceImpl;
+class FlightGrpcServiceImpl;
class GrpcServerCallContext : public ServerCallContext {
explicit GrpcServerCallContext(grpc::ServerContext* context)
: context_(context), peer_(context_->peer()) {}
@@ -421,7 +175,7 @@ class GrpcServerCallContext : public ServerCallContext {
}
private:
- friend class FlightServiceImpl;
+ friend class FlightGrpcServiceImpl;
ServerContext* context_;
std::string peer_;
std::string peer_identity_;
@@ -442,20 +196,69 @@ class GrpcAddCallHeaders : public AddCallHeaders {
grpc::ServerContext* context_;
};
+class GetDataStream : public internal::TransportDataStream {
+ public:
+ explicit GetDataStream(ServerWriter* writer) : writer_(writer) {}
+
+ Status WriteData(const FlightPayload& payload) override {
+ return internal::WritePayload(payload, writer_);
+ }
+
+ private:
+ ServerWriter