From 86f4789ab26d9048a8be6263354745a37bd9131d Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 9 Apr 2019 14:08:17 -0400
Subject: [PATCH 01/18] Add application metadata field to FlightData message
---
format/Flight.proto | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/format/Flight.proto b/format/Flight.proto
index 7f0488b86c3..f82a7e52450 100644
--- a/format/Flight.proto
+++ b/format/Flight.proto
@@ -77,7 +77,7 @@ service FlightService {
* number. In the latter, the service might implement a 'seal' action that
* can be applied to a descriptor once all streams are uploaded.
*/
- rpc DoPut(stream FlightData) returns (PutResult) {}
+ rpc DoPut(stream FlightData) returns (stream PutResult) {}
/*
* Flight services can support an arbitrary number of simple actions in
@@ -285,6 +285,11 @@ message FlightData {
*/
bytes data_header = 2;
+ /*
+ * Application-defined metadata.
+ */
+ bytes data_app_metadata = 3;
+
/*
* The actual batch of Arrow data. Preferably handled with minimal-copies
* coming last in the definition to help with sidecar patterns (it is
@@ -295,7 +300,8 @@ message FlightData {
}
/**
- * The response message (currently empty) associated with the submission of a
- * DoPut.
+ * The response message associated with the submission of a DoPut.
*/
-message PutResult {}
+message PutResult {
+ bytes data_app_metadata = 1;
+}
From a8ac27fb3b9319854acca7084c56da341d38aac7 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 9 Apr 2019 17:29:56 -0400
Subject: [PATCH 02/18] Implement application metadata in Flight
---
cpp/src/arrow/flight/client.cc | 206 +++++++++++--
cpp/src/arrow/flight/client.h | 33 +-
cpp/src/arrow/flight/flight-benchmark.cc | 2 +-
cpp/src/arrow/flight/flight-test.cc | 159 +++++++++-
cpp/src/arrow/flight/internal.h | 3 +-
.../arrow/flight/serialization-internal.cc | 28 +-
cpp/src/arrow/flight/serialization-internal.h | 8 +-
cpp/src/arrow/flight/server.cc | 75 ++++-
cpp/src/arrow/flight/server.h | 18 +-
.../arrow/flight/test-integration-client.cc | 50 ++-
.../arrow/flight/test-integration-server.cc | 12 +-
cpp/src/arrow/flight/test-util.cc | 16 +
cpp/src/arrow/flight/test-util.h | 18 ++
cpp/src/arrow/flight/types.h | 12 +
cpp/src/arrow/ipc/reader.cc | 2 +-
cpp/src/arrow/python/flight.cc | 8 +-
cpp/src/arrow/python/flight.h | 6 +-
format/Flight.proto | 4 +-
.../org/apache/arrow/flight/ArrowMessage.java | 31 +-
.../arrow/flight/FlightBindingService.java | 11 +-
.../org/apache/arrow/flight/FlightClient.java | 70 +++--
.../apache/arrow/flight/FlightProducer.java | 32 +-
.../org/apache/arrow/flight/FlightServer.java | 25 +-
.../apache/arrow/flight/FlightService.java | 16 +-
.../org/apache/arrow/flight/FlightStream.java | 14 +-
.../arrow/flight/NoOpFlightProducer.java | 8 +-
.../org/apache/arrow/flight/PutResult.java | 62 ++++
.../arrow/flight/example/InMemoryStore.java | 9 +-
.../apache/arrow/flight/example/Stream.java | 5 +-
.../integration/IntegrationTestClient.java | 29 +-
.../apache/arrow/flight/FlightTestUtil.java | 5 +-
.../arrow/flight/TestApplicationMetadata.java | 167 ++++++++++
.../apache/arrow/flight/TestBackPressure.java | 9 +-
.../arrow/flight/TestBasicOperation.java | 26 +-
.../apache/arrow/flight/TestCallOptions.java | 7 +-
.../apache/arrow/flight/TestLargeMessage.java | 39 ++-
.../apache/arrow/flight/auth/TestAuth.java | 7 +-
.../flight/example/TestExampleServer.java | 17 +-
.../flight/perf/PerformanceTestServer.java | 26 +-
.../apache/arrow/flight/perf/TestPerf.java | 3 +-
python/pyarrow/_flight.pyx | 284 ++++++++++++++++--
python/pyarrow/includes/libarrow_flight.pxd | 34 ++-
python/pyarrow/tests/test_flight.py | 117 +++++++-
43 files changed, 1449 insertions(+), 264 deletions(-)
create mode 100644 java/flight/src/main/java/org/apache/arrow/flight/PutResult.java
create mode 100644 java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 1926928c643..654e1e7c487 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -79,7 +79,7 @@ struct ClientRpc {
if (auth_handler) {
std::string token;
RETURN_NOT_OK(auth_handler->GetToken(&token));
- context.AddMetadata(internal::AUTH_HEADER, token);
+ context.AddMetadata(internal::kGrpcAuthHeader, token);
}
return Status::OK();
}
@@ -129,15 +129,47 @@ class GrpcClientAuthReader : public ClientAuthReader {
stream_;
};
-class FlightIpcMessageReader : public ipc::MessageReader {
+// The next two classes are intertwined. To get the application
+// metadata while avoiding reimplementing RecordBatchStreamReader, we
+// create an ipc::MessageReader that is tied to the
+// MetadataRecordBatchReader. Every time an IPC message is read, it updates
+// the application metadata field of the MetadataRecordBatchReader. The
+// MetadataRecordBatchReader wraps RecordBatchStreamReader, offering an
+// additional method to get both the record batch and application
+// metadata.
+
+class GrpcIpcMessageReader;
+class GrpcStreamReader : public MetadataRecordBatchReader {
public:
- FlightIpcMessageReader(std::unique_ptr rpc,
- std::unique_ptr> stream)
- : rpc_(std::move(rpc)), stream_(std::move(stream)), stream_finished_(false) {}
+ GrpcStreamReader();
+
+ static Status Open(std::unique_ptr rpc,
+ std::unique_ptr> stream,
+ std::unique_ptr* out);
+ std::shared_ptr schema() const override;
+ Status ReadNext(std::shared_ptr* out) override;
+ Status ReadWithMetadata(std::shared_ptr* out,
+ std::shared_ptr* app_metadata) override;
+
+ private:
+ friend class GrpcIpcMessageReader;
+ std::unique_ptr batch_reader_;
+ std::shared_ptr last_app_metadata_;
+};
+
+class GrpcIpcMessageReader : public ipc::MessageReader {
+ public:
+ GrpcIpcMessageReader(GrpcStreamReader* reader, std::unique_ptr rpc,
+ std::unique_ptr> stream)
+ : flight_reader_(reader),
+ rpc_(std::move(rpc)),
+ stream_(std::move(stream)),
+ stream_finished_(false) {}
Status ReadNextMessage(std::unique_ptr* out) override {
if (stream_finished_) {
*out = nullptr;
+ flight_reader_->last_app_metadata_ = nullptr;
return Status::OK();
}
internal::FlightData data;
@@ -145,13 +177,16 @@ class FlightIpcMessageReader : public ipc::MessageReader {
// Stream is completed
stream_finished_ = true;
*out = nullptr;
+ flight_reader_->last_app_metadata_ = nullptr;
return OverrideWithServerError(Status::OK());
}
// Validate IPC message
auto st = data.OpenMessage(out);
if (!st.ok()) {
+ flight_reader_->last_app_metadata_ = nullptr;
return OverrideWithServerError(std::move(st));
}
+ flight_reader_->last_app_metadata_ = data.app_metadata;
return Status::OK();
}
@@ -162,23 +197,95 @@ class FlightIpcMessageReader : public ipc::MessageReader {
return std::move(st);
}
+ private:
+ GrpcStreamReader* flight_reader_;
// The RPC context lifetime must be coupled to the ClientReader
std::unique_ptr rpc_;
std::unique_ptr> stream_;
bool stream_finished_;
};
+GrpcStreamReader::GrpcStreamReader() {}
+
+Status GrpcStreamReader::Open(std::unique_ptr rpc,
+ std::unique_ptr> stream,
+ std::unique_ptr* out) {
+ *out = std::unique_ptr(new GrpcStreamReader);
+ std::unique_ptr message_reader(
+ new GrpcIpcMessageReader(out->get(), std::move(rpc), std::move(stream)));
+ return ipc::RecordBatchStreamReader::Open(std::move(message_reader),
+ &(*out)->batch_reader_);
+}
+
+std::shared_ptr GrpcStreamReader::schema() const {
+ return batch_reader_->schema();
+}
+
+Status GrpcStreamReader::ReadNext(std::shared_ptr* out) {
+ std::shared_ptr app_metadata;
+ return ReadWithMetadata(out, &app_metadata);
+}
+
+Status GrpcStreamReader::ReadWithMetadata(std::shared_ptr* out,
+ std::shared_ptr* app_metadata) {
+ RETURN_NOT_OK(batch_reader_->ReadNext(out));
+ *app_metadata = std::move(last_app_metadata_);
+ return Status::OK();
+}
+
+// Similarly, the next two classes are intertwined. In order to get
+// application-specific metadata to the IpcPayloadWriter,
+// DoPutPayloadWriter takes a pointer to
+// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on
+// write; DoPutPayloadWriter reads that metadata field to determine
+// what to write.
+
+class DoPutPayloadWriter;
+class GrpcStreamWriter : public FlightStreamWriter {
+ public:
+ ~GrpcStreamWriter() = default;
+
+ GrpcStreamWriter() : app_metadata_(nullptr), batch_writer_(nullptr) {}
+
+ static Status Open(
+ const FlightDescriptor& descriptor, const std::shared_ptr& schema,
+ std::unique_ptr rpc, std::unique_ptr response,
+ std::shared_ptr> writer,
+ std::unique_ptr* out);
+
+ Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override {
+ return WriteWithMetadata(batch, nullptr, allow_64bit);
+ }
+ Status WriteWithMetadata(const RecordBatch& batch, std::shared_ptr app_metadata,
+ bool allow_64bit = false) override {
+ app_metadata_ = app_metadata;
+ return batch_writer_->WriteRecordBatch(batch, allow_64bit);
+ }
+ void set_memory_pool(MemoryPool* pool) override {
+ batch_writer_->set_memory_pool(pool);
+ }
+ Status Close() override { return batch_writer_->Close(); }
+
+ private:
+ friend class DoPutPayloadWriter;
+ std::shared_ptr app_metadata_;
+ std::unique_ptr batch_writer_;
+};
+
/// A IpcPayloadWriter implementation that writes to a DoPut stream
class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
public:
- DoPutPayloadWriter(const FlightDescriptor& descriptor, std::unique_ptr rpc,
- std::unique_ptr response,
- std::unique_ptr> writer)
+ DoPutPayloadWriter(
+ const FlightDescriptor& descriptor, std::unique_ptr rpc,
+ std::unique_ptr response,
+ std::shared_ptr> writer,
+ GrpcStreamWriter* stream_writer)
: descriptor_(descriptor),
rpc_(std::move(rpc)),
response_(std::move(response)),
writer_(std::move(writer)),
- first_payload_(true) {}
+ first_payload_(true),
+ stream_writer_(stream_writer) {}
~DoPutPayloadWriter() override = default;
@@ -201,6 +308,8 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
}
RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
first_payload_ = false;
+ } else if (stream_writer_->app_metadata_) {
+ payload.app_metadata = std::move(stream_writer_->app_metadata_);
}
if (!internal::WritePayload(payload, writer_.get())) {
@@ -211,6 +320,10 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
Status Close() override {
bool finished_writes = writer_->WritesDone();
+ // Drain the read side to avoid hanging
+ pb::PutResult message;
+ while (writer_->Read(&message)) {
+ }
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish()));
if (!finished_writes) {
return Status::UnknownError(
@@ -223,9 +336,47 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
// TODO: there isn't a way to access this as a user.
const FlightDescriptor descriptor_;
std::unique_ptr rpc_;
- std::unique_ptr response_;
- std::unique_ptr> writer_;
+ std::unique_ptr response_;
+ std::shared_ptr> writer_;
bool first_payload_;
+ GrpcStreamWriter* stream_writer_;
+};
+
+Status GrpcStreamWriter::Open(
+ const FlightDescriptor& descriptor, const std::shared_ptr& schema,
+ std::unique_ptr rpc, std::unique_ptr response,
+ std::shared_ptr> writer,
+ std::unique_ptr* out) {
+ std::unique_ptr result(new GrpcStreamWriter);
+ std::unique_ptr payload_writer(new DoPutPayloadWriter(
+ descriptor, std::move(rpc), std::move(response), writer, result.get()));
+ RETURN_NOT_OK(ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema,
+ &result->batch_writer_));
+ *out = std::move(result);
+ return Status::OK();
+}
+
+FlightMetadataReader::~FlightMetadataReader() = default;
+
+class GrpcMetadataReader : public FlightMetadataReader {
+ public:
+ explicit GrpcMetadataReader(
+ std::shared_ptr> reader)
+ : reader_(reader) {}
+
+ Status ReadMetadata(std::shared_ptr* out) override {
+ pb::PutResult message;
+ if (reader_->Read(&message)) {
+ *out = Buffer::FromString(std::move(*message.release_app_metadata()));
+ } else {
+ // Stream finished
+ *out = nullptr;
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr> reader_;
};
class FlightClient::FlightClientImpl {
@@ -367,7 +518,7 @@ class FlightClient::FlightClientImpl {
}
Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr* out) {
+ std::unique_ptr* out) {
pb::Ticket pb_ticket;
internal::ToProto(ticket, &pb_ticket);
@@ -376,25 +527,25 @@ class FlightClient::FlightClientImpl {
std::unique_ptr> stream(
stub_->DoGet(&rpc->context, pb_ticket));
- std::unique_ptr message_reader(
- new FlightIpcMessageReader(std::move(rpc), std::move(stream)));
- return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out);
+ std::unique_ptr reader;
+ RETURN_NOT_OK(GrpcStreamReader::Open(std::move(rpc), std::move(stream), &reader));
+ *out = std::move(reader);
+ return Status::OK();
}
Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
const std::shared_ptr& schema,
- std::unique_ptr* out) {
+ std::unique_ptr* out,
+ std::unique_ptr* reader) {
std::unique_ptr rpc(new ClientRpc(options));
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
- std::unique_ptr response(new protocol::PutResult);
- std::unique_ptr> writer(
- stub_->DoPut(&rpc->context, response.get()));
-
- std::unique_ptr payload_writer(
- new DoPutPayloadWriter(descriptor, std::move(rpc), std::move(response),
- std::move(writer)));
+ std::unique_ptr response(new pb::PutResult);
+ std::shared_ptr> writer(
+ stub_->DoPut(&rpc->context));
- return ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, out);
+ *reader = std::unique_ptr(new GrpcMetadataReader(writer));
+ return GrpcStreamWriter::Open(descriptor, schema, std::move(rpc), std::move(response),
+ writer, out);
}
private:
@@ -449,15 +600,16 @@ Status FlightClient::ListFlights(const FlightCallOptions& options,
}
Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr* stream) {
+ std::unique_ptr* stream) {
return impl_->DoGet(options, ticket, stream);
}
Status FlightClient::DoPut(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
const std::shared_ptr& schema,
- std::unique_ptr* stream) {
- return impl_->DoPut(options, descriptor, schema, stream);
+ std::unique_ptr* stream,
+ std::unique_ptr* reader) {
+ return impl_->DoPut(options, descriptor, schema, stream, reader);
}
} // namespace flight
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index b8a5d4f4b91..5cd7ce0f349 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -25,6 +25,7 @@
#include
#include
+#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/status.h"
@@ -35,7 +36,6 @@ namespace arrow {
class MemoryPool;
class RecordBatch;
-class RecordBatchReader;
class Schema;
namespace flight {
@@ -66,6 +66,24 @@ class ARROW_FLIGHT_EXPORT FlightClientOptions {
std::string override_hostname;
};
+/// \brief A RecordBatchWriter that also allows sending
+/// application-defined metadata via the Flight protocol.
+class ARROW_EXPORT FlightStreamWriter : public ipc::RecordBatchWriter {
+ public:
+ virtual Status WriteWithMetadata(const RecordBatch& batch,
+ std::shared_ptr app_metadata,
+ bool allow_64bit = false) = 0;
+};
+
+/// \brief A reader for application-specific metadata sent back to the
+/// client during an upload.
+class ARROW_EXPORT FlightMetadataReader {
+ public:
+ virtual ~FlightMetadataReader();
+ /// \brief Read a message from the server.
+ virtual Status ReadMetadata(std::shared_ptr* out) = 0;
+};
+
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_FLIGHT_EXPORT FlightClient {
@@ -151,8 +169,8 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[out] stream the returned RecordBatchReader
/// \return Status
Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
- std::unique_ptr* stream);
- Status DoGet(const Ticket& ticket, std::unique_ptr* stream) {
+ std::unique_ptr* stream);
+ Status DoGet(const Ticket& ticket, std::unique_ptr* stream) {
return DoGet({}, ticket, stream);
}
@@ -163,13 +181,16 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] descriptor the descriptor of the stream
/// \param[in] schema the schema for the data to upload
/// \param[out] stream a writer to write record batches to
+ /// \param[out] reader a reader for application metadata from the server
/// \return Status
Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
const std::shared_ptr& schema,
- std::unique_ptr* stream);
+ std::unique_ptr* stream,
+ std::unique_ptr* reader);
Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema,
- std::unique_ptr* stream) {
- return DoPut({}, descriptor, schema, stream);
+ std::unique_ptr* stream,
+ std::unique_ptr* reader) {
+ return DoPut({}, descriptor, schema, stream, reader);
}
private:
diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc
index f2bd356a647..32c584a609a 100644
--- a/cpp/src/arrow/flight/flight-benchmark.cc
+++ b/cpp/src/arrow/flight/flight-benchmark.cc
@@ -106,7 +106,7 @@ Status RunPerformanceTest(const std::string& hostname, const int port) {
perf::Token token;
token.ParseFromString(endpoint.ticket.ticket);
- std::unique_ptr reader;
+ std::unique_ptr reader;
RETURN_NOT_OK(client->DoGet(endpoint.ticket, &reader));
std::shared_ptr batch;
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index 3c0b67cd992..b22422f109c 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -211,7 +211,7 @@ class TestFlightClient : public ::testing::Test {
// By convention, fetch the first endpoint
Ticket ticket = info->endpoints()[0].ticket;
- std::unique_ptr stream;
+ std::unique_ptr stream;
ASSERT_OK(client_->DoGet(ticket, &stream));
std::shared_ptr chunk;
@@ -255,7 +255,8 @@ class TlsTestServer : public FlightServerBase {
class DoPutTestServer : public FlightServerBase {
public:
Status DoPut(const ServerCallContext& context,
- std::unique_ptr reader) override {
+ std::unique_ptr reader,
+ std::unique_ptr writer) override {
descriptor_ = reader->descriptor();
return reader->ReadAll(&batches_);
}
@@ -267,6 +268,71 @@ class DoPutTestServer : public FlightServerBase {
friend class TestDoPut;
};
+class MetadataTestServer : public FlightServerBase {
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr* data_stream) override {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ std::shared_ptr batch_reader =
+ std::make_shared(batches[0]->schema(), batches);
+
+ *data_stream = std::unique_ptr(new NumberingStream(
+ std::unique_ptr(new RecordBatchStream(batch_reader))));
+ return Status::OK();
+ }
+
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr reader,
+ std::unique_ptr writer) override {
+ std::shared_ptr chunk;
+ std::shared_ptr app_metadata;
+ int counter = 0;
+ while (true) {
+ RETURN_NOT_OK(reader->ReadWithMetadata(&chunk, &app_metadata));
+ if (chunk == nullptr) break;
+ if (app_metadata == nullptr) {
+ return Status::Invalid("Expected application metadata to be provided");
+ }
+ if (std::to_string(counter) != app_metadata->ToString()) {
+ return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
+ " but got: " + app_metadata->ToString());
+ }
+ auto metadata = Buffer::FromString(std::to_string(counter));
+ RETURN_NOT_OK(writer->WriteMetadata(*metadata));
+ counter++;
+ }
+ return Status::OK();
+ }
+};
+
+template
+class InsecureTestServer : public ::testing::Test {
+ public:
+ void SetUp() {
+ Location location;
+ ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location));
+
+ std::unique_ptr server(new T);
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+
+ server_.reset(new InProcessTestServer(std::move(server), location));
+ ASSERT_OK(server_->Start());
+ ASSERT_OK(ConnectClient());
+ }
+
+ void TearDown() { server_->Stop(); }
+
+ Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }
+
+ protected:
+ int port_;
+ std::unique_ptr client_;
+ std::unique_ptr server_;
+};
+
+using TestMetadata = InsecureTestServer;
+
class TestAuthHandler : public ::testing::Test {
public:
void SetUp() {
@@ -323,8 +389,9 @@ class TestDoPut : public ::testing::Test {
void CheckDoPut(FlightDescriptor descr, const std::shared_ptr& schema,
const BatchVector& batches) {
- std::unique_ptr stream;
- ASSERT_OK(client_->DoPut(descr, schema, &stream));
+ std::unique_ptr stream;
+ std::unique_ptr reader;
+ ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
for (const auto& batch : batches) {
ASSERT_OK(stream->WriteRecordBatch(*batch));
}
@@ -485,7 +552,7 @@ TEST_F(TestFlightClient, Issue5095) {
// Make sure the server-side error message is reflected to the
// client
Ticket ticket1{"ARROW-5095-fail"};
- std::unique_ptr stream;
+ std::unique_ptr stream;
Status status = client_->DoGet(ticket1, &stream);
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error"));
@@ -588,13 +655,14 @@ TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
status = client_->GetFlightInfo(FlightDescriptor{}, &info);
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr stream;
+ std::unique_ptr stream;
status = client_->DoGet(Ticket{}, &stream);
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr writer;
+ std::unique_ptr writer;
+ std::unique_ptr reader;
std::shared_ptr schema = arrow::schema({});
- status = client_->DoPut(FlightDescriptor{}, schema, &writer);
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
ASSERT_OK(status);
status = writer->Close();
ASSERT_RAISES(NotImplemented, status);
@@ -625,15 +693,16 @@ TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr stream;
+ std::unique_ptr stream;
status = client_->DoGet(Ticket{}, &stream);
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr writer;
+ std::unique_ptr writer;
+ std::unique_ptr reader;
std::shared_ptr schema(
(new arrow::Schema(std::vector>())));
- status = client_->DoPut(FlightDescriptor{}, schema, &writer);
+ status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
ASSERT_OK(status);
status = writer->Close();
ASSERT_RAISES(IOError, status);
@@ -693,5 +762,73 @@ TEST_F(TestTls, OverrideHostname) {
ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
}
+TEST_F(TestMetadata, DoGet) {
+ Ticket ticket{""};
+ std::unique_ptr stream;
+ ASSERT_OK(client_->DoGet(ticket, &stream));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr chunk;
+ std::shared_ptr metadata;
+ auto num_batches = static_cast(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(stream->ReadWithMetadata(&chunk, &metadata));
+ ASSERT_NE(nullptr, chunk);
+ ASSERT_NE(nullptr, metadata);
+ ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk);
+ ASSERT_EQ(std::to_string(i), metadata->ToString());
+ }
+ ASSERT_OK(stream->ReadNext(&chunk));
+ ASSERT_EQ(nullptr, chunk);
+}
+
+TEST_F(TestMetadata, DoPut) {
+ std::unique_ptr writer;
+ std::unique_ptr reader;
+ std::shared_ptr schema = ExampleIntSchema();
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr chunk;
+ std::shared_ptr metadata;
+ auto num_batches = static_cast(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ }
+ // This eventually calls grpc::ClientReaderWriter::Finish which can
+ // hang if there are unread messages. So make sure our wrapper
+ // around this doesn't hang (because it drains any unread messages)
+ ASSERT_OK(writer->Close());
+}
+
+TEST_F(TestMetadata, DoPutReadMetadata) {
+ std::unique_ptr writer;
+ std::unique_ptr reader;
+ std::shared_ptr schema = ExampleIntSchema();
+ ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+
+ BatchVector expected_batches;
+ ASSERT_OK(ExampleIntBatches(&expected_batches));
+
+ std::shared_ptr chunk;
+ std::shared_ptr metadata;
+ auto num_batches = static_cast(expected_batches.size());
+ for (int i = 0; i < num_batches; ++i) {
+ ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
+ Buffer::FromString(std::to_string(i))));
+ ASSERT_OK(reader->ReadMetadata(&metadata));
+ ASSERT_NE(nullptr, metadata);
+ ASSERT_EQ(std::to_string(i), metadata->ToString());
+ }
+ // As opposed to DoPutDrainMetadata, now we've read the messages, so
+ // make sure this still closes as expected.
+ ASSERT_OK(writer->Close());
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h
index 784e8ebae1c..5283bed2183 100644
--- a/cpp/src/arrow/flight/internal.h
+++ b/cpp/src/arrow/flight/internal.h
@@ -63,7 +63,8 @@ namespace flight {
namespace internal {
-static const char* AUTH_HEADER = "auth-token-bin";
+/// The name of the header used to pass authentication tokens.
+static const char* kGrpcAuthHeader = "auth-token-bin";
ARROW_FLIGHT_EXPORT
Status SchemaToString(const Schema& schema, std::string* out);
diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc
index d78bac83870..0ff5abad692 100644
--- a/cpp/src/arrow/flight/serialization-internal.cc
+++ b/cpp/src/arrow/flight/serialization-internal.cc
@@ -163,6 +163,14 @@ grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
// 1 byte for metadata tag
header_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size);
+ // App metadata tag if appropriate
+ int32_t app_metadata_size = 0;
+ if (msg.app_metadata && msg.app_metadata->size() > 0) {
+ DCHECK_LT(msg.app_metadata->size(), kInt32Max);
+ app_metadata_size = static_cast(msg.app_metadata->size());
+ header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size);
+ }
+
for (const auto& buffer : ipc_msg.body_buffers) {
// Buffer may be null when the row length is zero, or when all
// entries are invalid.
@@ -214,6 +222,15 @@ grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out,
header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
static_cast(ipc_msg.metadata->size()));
+ // Write app metadata
+ if (app_metadata_size > 0) {
+ WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber,
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream);
+ header_stream.WriteVarint32(app_metadata_size);
+ header_stream.WriteRawMaybeAliased(msg.app_metadata->data(),
+ static_cast(msg.app_metadata->size()));
+ }
+
if (has_body) {
// Write body tag
WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
@@ -292,6 +309,12 @@ grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) {
"Unable to read FlightData metadata");
}
} break;
+ case pb::FlightData::kAppMetadataFieldNumber: {
+ if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) {
+ return grpc::Status(grpc::StatusCode::INTERNAL,
+ "Unable to read FlightData application metadata");
+ }
+ } break;
case pb::FlightData::kDataBodyFieldNumber: {
if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) {
return grpc::Status(grpc::StatusCode::INTERNAL,
@@ -330,7 +353,7 @@ Status FlightData::OpenMessage(std::unique_ptr* message) {
// (see customize_protobuf.h).
bool WritePayload(const FlightPayload& payload,
- grpc::ClientWriter* writer) {
+ grpc::ClientReaderWriter* writer) {
// Pretend to be pb::FlightData and intercept in SerializationTraits
return writer->Write(*reinterpret_cast(&payload),
grpc::WriteOptions());
@@ -348,7 +371,8 @@ bool ReadPayload(grpc::ClientReader* reader, FlightData* data) {
return reader->Read(reinterpret_cast(data));
}
-bool ReadPayload(grpc::ServerReader* reader, FlightData* data) {
+bool ReadPayload(grpc::ServerReaderWriter* reader,
+ FlightData* data) {
// Pretend to be pb::FlightData and intercept in SerializationTraits
return reader->Read(reinterpret_cast(data));
}
diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h
index aa47af6ae35..cfb8b8ab048 100644
--- a/cpp/src/arrow/flight/serialization-internal.h
+++ b/cpp/src/arrow/flight/serialization-internal.h
@@ -43,6 +43,9 @@ struct FlightData {
/// Non-length-prefixed Message header as described in format/Message.fbs
std::shared_ptr metadata;
+ /// Application-defined metadata
+ std::shared_ptr app_metadata;
+
/// Message body
std::shared_ptr body;
@@ -53,14 +56,15 @@ struct FlightData {
/// Write Flight message on gRPC stream with zero-copy optimizations.
/// True is returned on success, false if some error occurred (connection closed?).
bool WritePayload(const FlightPayload& payload,
- grpc::ClientWriter* writer);
+ grpc::ClientReaderWriter* writer);
bool WritePayload(const FlightPayload& payload,
grpc::ServerWriter* writer);
/// Read Flight message from gRPC stream with zero-copy optimizations.
/// True is returned on success, false if stream ended.
bool ReadPayload(grpc::ClientReader* reader, FlightData* data);
-bool ReadPayload(grpc::ServerReader* reader, FlightData* data);
+bool ReadPayload(grpc::ServerReaderWriter* reader,
+ FlightData* data);
} // namespace internal
} // namespace flight
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 6f3c466c4ad..3bc1a6aacb5 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -72,12 +72,15 @@ namespace {
// A MessageReader implementation that reads from a gRPC ServerReader
class FlightIpcMessageReader : public ipc::MessageReader {
public:
- explicit FlightIpcMessageReader(grpc::ServerReader* reader)
- : reader_(reader) {}
+ explicit FlightIpcMessageReader(
+ grpc::ServerReaderWriter* reader,
+ std::shared_ptr* last_metadata)
+ : reader_(reader), app_metadata_(last_metadata) {}
Status ReadNextMessage(std::unique_ptr* out) override {
if (stream_finished_) {
*out = nullptr;
+ *app_metadata_ = nullptr;
return Status::OK();
}
internal::FlightData data;
@@ -89,6 +92,7 @@ class FlightIpcMessageReader : public ipc::MessageReader {
"Client provided malformed message or did not provide message");
}
*out = nullptr;
+ *app_metadata_ = nullptr;
return Status::OK();
}
@@ -100,25 +104,29 @@ class FlightIpcMessageReader : public ipc::MessageReader {
first_message_ = false;
}
- return data.OpenMessage(out);
+ RETURN_NOT_OK(data.OpenMessage(out));
+ *app_metadata_ = std::move(data.app_metadata);
+ return Status::OK();
}
const FlightDescriptor& descriptor() const { return descriptor_; }
protected:
- grpc::ServerReader* reader_;
+ grpc::ServerReaderWriter* reader_;
bool stream_finished_ = false;
bool first_message_ = true;
FlightDescriptor descriptor_;
+ std::shared_ptr* app_metadata_;
};
class FlightMessageReaderImpl : public FlightMessageReader {
public:
- explicit FlightMessageReaderImpl(grpc::ServerReader* reader)
+ explicit FlightMessageReaderImpl(
+ grpc::ServerReaderWriter* reader)
: reader_(reader) {}
Status Init() {
- message_reader_ = new FlightIpcMessageReader(reader_);
+ message_reader_ = new FlightIpcMessageReader(reader_, &last_metadata_);
return ipc::RecordBatchStreamReader::Open(
std::unique_ptr(message_reader_), &batch_reader_);
}
@@ -130,17 +138,49 @@ class FlightMessageReaderImpl : public FlightMessageReader {
std::shared_ptr schema() const override { return batch_reader_->schema(); }
Status ReadNext(std::shared_ptr* out) override {
- return batch_reader_->ReadNext(out);
+ return ReadWithMetadata(out, nullptr);
+ }
+
+ Status ReadWithMetadata(std::shared_ptr* out,
+ std::shared_ptr* app_metadata) override {
+ if (app_metadata) {
+ *app_metadata = nullptr;
+ }
+ RETURN_NOT_OK(batch_reader_->ReadNext(out));
+ if (app_metadata) {
+ *app_metadata = std::move(last_metadata_);
+ }
+ return Status::OK();
}
private:
std::shared_ptr schema_;
std::unique_ptr dictionary_memo_;
- grpc::ServerReader* reader_;
+ grpc::ServerReaderWriter* reader_;
FlightIpcMessageReader* message_reader_;
+ std::shared_ptr last_metadata_;
std::shared_ptr batch_reader_;
};
+class GrpcMetadataWriter : public FlightMetadataWriter {
+ public:
+ explicit GrpcMetadataWriter(
+ grpc::ServerReaderWriter* writer)
+ : writer_(writer) {}
+
+ Status WriteMetadata(const Buffer& buffer) override {
+ pb::PutResult message{};
+ message.set_app_metadata(buffer.data(), buffer.size());
+ if (writer_->Write(message)) {
+ return Status::OK();
+ }
+ return Status::IOError("Unknown error writing metadata.");
+ }
+
+ private:
+ grpc::ServerReaderWriter* writer_;
+};
+
class GrpcServerAuthReader : public ServerAuthReader {
public:
explicit GrpcServerAuthReader(
@@ -153,7 +193,7 @@ class GrpcServerAuthReader : public ServerAuthReader {
*token = std::move(*request.release_payload());
return Status::OK();
}
- return Status::UnknownError("Could not read client handshake request.");
+ return Status::IOError("Stream is closed.");
}
private:
@@ -246,7 +286,7 @@ class FlightServiceImpl : public FlightService::Service {
}
const auto client_metadata = context->client_metadata();
- const auto auth_header = client_metadata.find(internal::AUTH_HEADER);
+ const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader);
std::string token;
if (auth_header == client_metadata.end()) {
token = "";
@@ -349,16 +389,18 @@ class FlightServiceImpl : public FlightService::Service {
return grpc::Status::OK;
}
- grpc::Status DoPut(ServerContext* context, grpc::ServerReader* reader,
- pb::PutResult* response) {
+ grpc::Status DoPut(ServerContext* context,
+ grpc::ServerReaderWriter* reader) {
GrpcServerCallContext flight_context;
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(context, flight_context));
auto message_reader =
std::unique_ptr(new FlightMessageReaderImpl(reader));
GRPC_RETURN_NOT_OK(message_reader->Init());
- return internal::ToGrpcStatus(
- server_->DoPut(flight_context, std::move(message_reader)));
+ auto metadata_writer =
+ std::unique_ptr(new GrpcMetadataWriter(reader));
+ return internal::ToGrpcStatus(server_->DoPut(
+ flight_context, std::move(message_reader), std::move(metadata_writer)));
}
grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
@@ -410,6 +452,8 @@ class FlightServiceImpl : public FlightService::Service {
} // namespace
+FlightMetadataWriter::~FlightMetadataWriter() = default;
+
//
// gRPC server lifecycle
//
@@ -572,7 +616,8 @@ Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& r
}
Status FlightServerBase::DoPut(const ServerCallContext& context,
- std::unique_ptr reader) {
+ std::unique_ptr reader,
+ std::unique_ptr writer) {
return Status::NotImplemented("NYI");
}
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index c1bcb5c0a3d..8792ec78e5e 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -81,8 +81,9 @@ class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream {
#pragma warning(disable : 4275)
#endif
-/// \brief A reader for IPC payloads uploaded by a client
-class ARROW_FLIGHT_EXPORT FlightMessageReader : public RecordBatchReader {
+/// \brief A reader for IPC payloads uploaded by a client. Also allows
+/// reading application-defined metadata via the Flight protocol.
+class ARROW_FLIGHT_EXPORT FlightMessageReader : public MetadataRecordBatchReader {
public:
/// \brief Get the descriptor for this upload.
virtual const FlightDescriptor& descriptor() const = 0;
@@ -92,6 +93,15 @@ class ARROW_FLIGHT_EXPORT FlightMessageReader : public RecordBatchReader {
#pragma warning(pop)
#endif
+/// \brief A writer for application-specific metadata sent back to the
+/// client during an upload.
+class ARROW_FLIGHT_EXPORT FlightMetadataWriter {
+ public:
+ virtual ~FlightMetadataWriter();
+ /// \brief Send a message to the client.
+ virtual Status WriteMetadata(const Buffer& app_metadata) = 0;
+};
+
/// \brief Call state/contextual data.
class ARROW_FLIGHT_EXPORT ServerCallContext {
public:
@@ -178,9 +188,11 @@ class ARROW_FLIGHT_EXPORT FlightServerBase {
/// \brief Process a stream of IPC payloads sent from a client
/// \param[in] context The call context.
/// \param[in] reader a sequence of uploaded record batches
+ /// \param[in] writer send metadata back to the client
/// \return Status
virtual Status DoPut(const ServerCallContext& context,
- std::unique_ptr reader);
+ std::unique_ptr reader,
+ std::unique_ptr writer);
/// \brief Execute an action, return stream of zero or more results
/// \param[in] context The call context.
diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc
index abaa3bc4221..ec2b79c3737 100644
--- a/cpp/src/arrow/flight/test-integration-client.cc
+++ b/cpp/src/arrow/flight/test-integration-client.cc
@@ -44,17 +44,27 @@ DEFINE_string(host, "localhost", "Server port to connect to");
DEFINE_int32(port, 31337, "Server port to connect to");
DEFINE_string(path, "", "Resource path to request");
-/// \brief Helper to read a RecordBatchReader into a Table.
-arrow::Status ReadToTable(std::unique_ptr& reader,
+/// \brief Helper to read a MetadataRecordBatchReader into a Table.
+arrow::Status ReadToTable(arrow::flight::MetadataRecordBatchReader& reader,
std::shared_ptr* retrieved_data) {
+ // For integration testing, we expect the server numbers the
+ // batches, to test the application metadata part of the spec.
std::vector> retrieved_chunks;
std::shared_ptr chunk;
+ std::shared_ptr metadata_chunk;
+ int counter = 0;
while (true) {
- RETURN_NOT_OK(reader->ReadNext(&chunk));
+ RETURN_NOT_OK(reader.ReadWithMetadata(&chunk, &metadata_chunk));
if (chunk == nullptr) break;
retrieved_chunks.push_back(chunk);
+ if (std::to_string(counter) != metadata_chunk->ToString()) {
+ return arrow::Status::Invalid(
+ "Expected metadata value: " + std::to_string(counter) +
+ " but got: " + metadata_chunk->ToString());
+ }
+ counter++;
}
- return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks,
+ return arrow::Table::FromRecordBatches(reader.schema(), retrieved_chunks,
retrieved_data);
}
@@ -71,14 +81,27 @@ arrow::Status ReadToTable(std::unique_ptr& reader,
- arrow::ipc::RecordBatchWriter& writer) {
+/// \brief Upload the contents of a RecordBatchReader to a Flight
+/// server, validating the application metadata on the side.
+arrow::Status UploadReaderToFlight(arrow::RecordBatchReader* reader,
+ arrow::flight::FlightStreamWriter& writer,
+ arrow::flight::FlightMetadataReader& metadata_reader) {
+ int counter = 0;
while (true) {
std::shared_ptr chunk;
RETURN_NOT_OK(reader->ReadNext(&chunk));
if (chunk == nullptr) break;
- RETURN_NOT_OK(writer.WriteRecordBatch(*chunk));
+ std::shared_ptr metadata =
+ arrow::Buffer::FromString(std::to_string(counter));
+ RETURN_NOT_OK(writer.WriteWithMetadata(*chunk, metadata));
+ // Wait for the server to ack the result
+ std::shared_ptr ack_metadata;
+ RETURN_NOT_OK(metadata_reader.ReadMetadata(&ack_metadata));
+ if (!ack_metadata->Equals(*metadata)) {
+ return arrow::Status::Invalid("Expected metadata value: " + metadata->ToString() +
+ " but got: " + ack_metadata->ToString());
+ }
+ counter++;
}
return writer.Close();
}
@@ -91,10 +114,10 @@ arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location,
std::unique_ptr read_client;
RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &read_client));
- std::unique_ptr stream;
+ std::unique_ptr stream;
RETURN_NOT_OK(read_client->DoGet(ticket, &stream));
- return ReadToTable(stream, retrieved_data);
+ return ReadToTable(*stream, retrieved_data);
}
int main(int argc, char** argv) {
@@ -120,11 +143,12 @@ int main(int argc, char** argv) {
std::shared_ptr original_data;
ABORT_NOT_OK(ReadToTable(reader, &original_data));
- std::unique_ptr write_stream;
- ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream));
+ std::unique_ptr write_stream;
+ std::unique_ptr metadata_reader;
+ ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream, &metadata_reader));
std::unique_ptr table_reader(
new arrow::TableBatchReader(*original_data));
- ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream));
+ ABORT_NOT_OK(UploadReaderToFlight(table_reader.get(), *write_stream, *metadata_reader));
// 2. Get the ticket for the data.
std::unique_ptr info;
diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc
index c5bb180663b..6d04588c190 100644
--- a/cpp/src/arrow/flight/test-integration-server.cc
+++ b/cpp/src/arrow/flight/test-integration-server.cc
@@ -79,14 +79,16 @@ class FlightIntegrationTestServer : public FlightServerBase {
}
auto flight = data->second;
- *data_stream = std::unique_ptr(new RecordBatchStream(
- std::shared_ptr(new TableBatchReader(*flight))));
+ *data_stream = std::unique_ptr(
+ new NumberingStream(std::unique_ptr(new RecordBatchStream(
+ std::shared_ptr(new TableBatchReader(*flight))))));
return Status::OK();
}
Status DoPut(const ServerCallContext& context,
- std::unique_ptr reader) override {
+ std::unique_ptr reader,
+ std::unique_ptr writer) override {
const FlightDescriptor& descriptor = reader->descriptor();
if (descriptor.type != FlightDescriptor::DescriptorType::PATH) {
@@ -99,10 +101,12 @@ class FlightIntegrationTestServer : public FlightServerBase {
std::vector> retrieved_chunks;
std::shared_ptr chunk;
+ std::shared_ptr metadata;
while (true) {
- RETURN_NOT_OK(reader->ReadNext(&chunk));
+ RETURN_NOT_OK(reader->ReadWithMetadata(&chunk, &metadata));
if (chunk == nullptr) break;
retrieved_chunks.push_back(chunk);
+ RETURN_NOT_OK(writer->WriteMetadata(*metadata));
}
std::shared_ptr retrieved_data;
RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks,
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index 7dd78fdd6eb..c84b2eda81c 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -260,6 +260,22 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
return internal::SchemaToString(schema, &out->schema);
}
+NumberingStream::NumberingStream(std::unique_ptr stream)
+ : counter_(0), stream_(std::move(stream)) {}
+
+std::shared_ptr NumberingStream::schema() { return stream_->schema(); }
+
+Status NumberingStream::GetSchemaPayload(FlightPayload* payload) {
+ return stream_->GetSchemaPayload(payload);
+}
+
+Status NumberingStream::Next(FlightPayload* payload) {
+ RETURN_NOT_OK(stream_->Next(payload));
+ payload->app_metadata = Buffer::FromString(std::to_string(counter_));
+ counter_++;
+ return Status::OK();
+}
+
std::shared_ptr ExampleIntSchema() {
auto f0 = field("f0", int32());
auto f1 = field("f1", int32());
diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h
index 5b02630b432..54e9b03bce1 100644
--- a/cpp/src/arrow/flight/test-util.h
+++ b/cpp/src/arrow/flight/test-util.h
@@ -25,6 +25,7 @@
#include "arrow/status.h"
#include "arrow/flight/client_auth.h"
+#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
@@ -127,6 +128,23 @@ class ARROW_FLIGHT_EXPORT BatchIterator : public RecordBatchReader {
#pragma warning(pop)
#endif
+// ----------------------------------------------------------------------
+// A FlightDataStream that numbers the record batches
+/// \brief A basic implementation of FlightDataStream that will provide
+/// a sequence of FlightData messages to be written to a gRPC stream
+class ARROW_EXPORT NumberingStream : public FlightDataStream {
+ public:
+ explicit NumberingStream(std::unique_ptr stream);
+
+ std::shared_ptr schema() override;
+ Status GetSchemaPayload(FlightPayload* payload) override;
+ Status Next(FlightPayload* payload) override;
+
+ private:
+ int counter_;
+ std::shared_ptr stream_;
+};
+
// ----------------------------------------------------------------------
// Example data for test-server and unit tests
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index e5f7bcdd550..d885e0d778f 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -27,6 +27,7 @@
#include
#include "arrow/flight/visibility.h"
+#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
namespace arrow {
@@ -205,6 +206,7 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint {
/// This structure corresponds to FlightData in the protocol.
struct ARROW_FLIGHT_EXPORT FlightPayload {
std::shared_ptr descriptor;
+ std::shared_ptr app_metadata;
ipc::internal::IpcPayload ipc_message;
};
@@ -278,6 +280,16 @@ class ARROW_FLIGHT_EXPORT ResultStream {
virtual Status Next(std::unique_ptr* info) = 0;
};
+/// \brief A RecordBatchReader that also exposes application-defined
+/// metadata from the Flight protocol.
+class ARROW_EXPORT MetadataRecordBatchReader : public RecordBatchReader {
+ public:
+ virtual ~MetadataRecordBatchReader() = default;
+
+ virtual Status ReadWithMetadata(std::shared_ptr* out,
+ std::shared_ptr* app_metadata) = 0;
+};
+
// \brief Create a FlightListing from a vector of FlightInfo objects. This can
// be iterated once, then it is consumed
class ARROW_FLIGHT_EXPORT SimpleFlightListing : public FlightListing {
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 717f29aaf40..992c4fc138f 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -455,7 +455,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
RETURN_NOT_OK(
ReadMessageAndValidate(message_reader_.get(), /*allow_null=*/false, &message));
- CHECK_MESSAGE_TYPE(message->type(), Message::SCHEMA);
+ CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type());
CHECK_HAS_NO_BODY(*message);
if (message->header() == nullptr) {
return Status::IOError("Header-pointer of flatbuffer-encoded Message is null.");
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index ee19fb9d1e8..da5026a7f9c 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -108,10 +108,12 @@ Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context,
});
}
-Status PyFlightServer::DoPut(const arrow::flight::ServerCallContext& context,
- std::unique_ptr reader) {
+Status PyFlightServer::DoPut(
+ const arrow::flight::ServerCallContext& context,
+ std::unique_ptr reader,
+ std::unique_ptr writer) {
return SafeCallIntoPython([&] {
- vtable_.do_put(server_.obj(), context, std::move(reader));
+ vtable_.do_put(server_.obj(), context, std::move(reader), std::move(writer));
return CheckPyError();
});
}
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index aecb97a10f5..5aea7e85259 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -50,7 +50,8 @@ class ARROW_PYTHON_EXPORT PyFlightServerVtable {
std::unique_ptr*)>
do_get;
std::function)>
+ std::unique_ptr