diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 1926928c643..f81b6278f11 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -79,7 +79,7 @@ struct ClientRpc { if (auth_handler) { std::string token; RETURN_NOT_OK(auth_handler->GetToken(&token)); - context.AddMetadata(internal::AUTH_HEADER, token); + context.AddMetadata(internal::kGrpcAuthHeader, token); } return Status::OK(); } @@ -129,15 +129,47 @@ class GrpcClientAuthReader : public ClientAuthReader { stream_; }; -class FlightIpcMessageReader : public ipc::MessageReader { +// The next two classes are intertwined. To get the application +// metadata while avoiding reimplementing RecordBatchStreamReader, we +// create an ipc::MessageReader that is tied to the +// MetadataRecordBatchReader. Every time an IPC message is read, it updates +// the application metadata field of the MetadataRecordBatchReader. The +// MetadataRecordBatchReader wraps RecordBatchStreamReader, offering an +// additional method to get both the record batch and application +// metadata. 
+ +class GrpcIpcMessageReader; +class GrpcStreamReader : public FlightStreamReader { public: - FlightIpcMessageReader(std::unique_ptr rpc, - std::unique_ptr> stream) - : rpc_(std::move(rpc)), stream_(std::move(stream)), stream_finished_(false) {} + GrpcStreamReader(); + + static Status Open(std::unique_ptr rpc, + std::unique_ptr> stream, + std::unique_ptr* out); + std::shared_ptr schema() const override; + Status Next(FlightStreamChunk* out) override; + void Cancel() override; + + private: + friend class GrpcIpcMessageReader; + std::unique_ptr batch_reader_; + std::shared_ptr last_app_metadata_; + std::shared_ptr rpc_; +}; + +class GrpcIpcMessageReader : public ipc::MessageReader { + public: + GrpcIpcMessageReader(GrpcStreamReader* reader, std::shared_ptr rpc, + std::unique_ptr> stream) + : flight_reader_(reader), + rpc_(rpc), + stream_(std::move(stream)), + stream_finished_(false) {} Status ReadNextMessage(std::unique_ptr* out) override { if (stream_finished_) { *out = nullptr; + flight_reader_->last_app_metadata_ = nullptr; return Status::OK(); } internal::FlightData data; @@ -145,13 +177,16 @@ class FlightIpcMessageReader : public ipc::MessageReader { // Stream is completed stream_finished_ = true; *out = nullptr; + flight_reader_->last_app_metadata_ = nullptr; return OverrideWithServerError(Status::OK()); } // Validate IPC message auto st = data.OpenMessage(out); if (!st.ok()) { + flight_reader_->last_app_metadata_ = nullptr; return OverrideWithServerError(std::move(st)); } + flight_reader_->last_app_metadata_ = data.app_metadata; return Status::OK(); } @@ -162,23 +197,93 @@ class FlightIpcMessageReader : public ipc::MessageReader { return std::move(st); } + private: + GrpcStreamReader* flight_reader_; // The RPC context lifetime must be coupled to the ClientReader - std::unique_ptr rpc_; + std::shared_ptr rpc_; std::unique_ptr> stream_; bool stream_finished_; }; +GrpcStreamReader::GrpcStreamReader() {} + +Status GrpcStreamReader::Open(std::unique_ptr rpc, + 
std::unique_ptr> stream, + std::unique_ptr* out) { + *out = std::unique_ptr(new GrpcStreamReader); + out->get()->rpc_ = std::move(rpc); + std::unique_ptr message_reader( + new GrpcIpcMessageReader(out->get(), out->get()->rpc_, std::move(stream))); + return ipc::RecordBatchStreamReader::Open(std::move(message_reader), + &(*out)->batch_reader_); +} + +std::shared_ptr GrpcStreamReader::schema() const { + return batch_reader_->schema(); +} + +Status GrpcStreamReader::Next(FlightStreamChunk* out) { + out->app_metadata = nullptr; + RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); + out->app_metadata = std::move(last_app_metadata_); + return Status::OK(); +} + +void GrpcStreamReader::Cancel() { rpc_->context.TryCancel(); } + +// Similarly, the next two classes are intertwined. In order to get +// application-specific metadata to the IpcPayloadWriter, +// DoPutPayloadWriter takes a pointer to +// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on +// write; DoPutPayloadWriter reads that metadata field to determine +// what to write. 
+ +class DoPutPayloadWriter; +class GrpcStreamWriter : public FlightStreamWriter { + public: + ~GrpcStreamWriter() = default; + + GrpcStreamWriter() : app_metadata_(nullptr), batch_writer_(nullptr) {} + + static Status Open( + const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr rpc, std::unique_ptr response, + std::shared_ptr> writer, + std::unique_ptr* out); + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + return WriteWithMetadata(batch, nullptr, allow_64bit); + } + Status WriteWithMetadata(const RecordBatch& batch, std::shared_ptr app_metadata, + bool allow_64bit = false) override { + app_metadata_ = app_metadata; + return batch_writer_->WriteRecordBatch(batch, allow_64bit); + } + void set_memory_pool(MemoryPool* pool) override { + batch_writer_->set_memory_pool(pool); + } + Status Close() override { return batch_writer_->Close(); } + + private: + friend class DoPutPayloadWriter; + std::shared_ptr app_metadata_; + std::unique_ptr batch_writer_; +}; + /// A IpcPayloadWriter implementation that writes to a DoPut stream class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { public: - DoPutPayloadWriter(const FlightDescriptor& descriptor, std::unique_ptr rpc, - std::unique_ptr response, - std::unique_ptr> writer) + DoPutPayloadWriter( + const FlightDescriptor& descriptor, std::unique_ptr rpc, + std::unique_ptr response, + std::shared_ptr> writer, + GrpcStreamWriter* stream_writer) : descriptor_(descriptor), rpc_(std::move(rpc)), response_(std::move(response)), writer_(std::move(writer)), - first_payload_(true) {} + first_payload_(true), + stream_writer_(stream_writer) {} ~DoPutPayloadWriter() override = default; @@ -201,6 +306,9 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { } RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor)); first_payload_ = false; + } else if (ipc_payload.type == ipc::Message::RECORD_BATCH && + 
stream_writer_->app_metadata_) { + payload.app_metadata = std::move(stream_writer_->app_metadata_); } if (!internal::WritePayload(payload, writer_.get())) { @@ -211,6 +319,10 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { Status Close() override { bool finished_writes = writer_->WritesDone(); + // Drain the read side to avoid hanging + pb::PutResult message; + while (writer_->Read(&message)) { + } RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish())); if (!finished_writes) { return Status::UnknownError( @@ -223,9 +335,47 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { // TODO: there isn't a way to access this as a user. const FlightDescriptor descriptor_; std::unique_ptr rpc_; - std::unique_ptr response_; - std::unique_ptr> writer_; + std::unique_ptr response_; + std::shared_ptr> writer_; bool first_payload_; + GrpcStreamWriter* stream_writer_; +}; + +Status GrpcStreamWriter::Open( + const FlightDescriptor& descriptor, const std::shared_ptr& schema, + std::unique_ptr rpc, std::unique_ptr response, + std::shared_ptr> writer, + std::unique_ptr* out) { + std::unique_ptr result(new GrpcStreamWriter); + std::unique_ptr payload_writer(new DoPutPayloadWriter( + descriptor, std::move(rpc), std::move(response), writer, result.get())); + RETURN_NOT_OK(ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, + &result->batch_writer_)); + *out = std::move(result); + return Status::OK(); +} + +FlightMetadataReader::~FlightMetadataReader() = default; + +class GrpcMetadataReader : public FlightMetadataReader { + public: + explicit GrpcMetadataReader( + std::shared_ptr> reader) + : reader_(reader) {} + + Status ReadMetadata(std::shared_ptr* out) override { + pb::PutResult message; + if (reader_->Read(&message)) { + *out = Buffer::FromString(std::move(*message.release_app_metadata())); + } else { + // Stream finished + *out = nullptr; + } + return Status::OK(); + } + + private: + std::shared_ptr> reader_; }; class 
FlightClient::FlightClientImpl { @@ -367,7 +517,7 @@ class FlightClient::FlightClientImpl { } Status DoGet(const FlightCallOptions& options, const Ticket& ticket, - std::unique_ptr* out) { + std::unique_ptr* out) { pb::Ticket pb_ticket; internal::ToProto(ticket, &pb_ticket); @@ -376,25 +526,25 @@ class FlightClient::FlightClientImpl { std::unique_ptr> stream( stub_->DoGet(&rpc->context, pb_ticket)); - std::unique_ptr message_reader( - new FlightIpcMessageReader(std::move(rpc), std::move(stream))); - return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out); + std::unique_ptr reader; + RETURN_NOT_OK(GrpcStreamReader::Open(std::move(rpc), std::move(stream), &reader)); + *out = std::move(reader); + return Status::OK(); } Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* out) { + std::unique_ptr* out, + std::unique_ptr* reader) { std::unique_ptr rpc(new ClientRpc(options)); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); - std::unique_ptr response(new protocol::PutResult); - std::unique_ptr> writer( - stub_->DoPut(&rpc->context, response.get())); - - std::unique_ptr payload_writer( - new DoPutPayloadWriter(descriptor, std::move(rpc), std::move(response), - std::move(writer))); + std::unique_ptr response(new pb::PutResult); + std::shared_ptr> writer( + stub_->DoPut(&rpc->context)); - return ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, out); + *reader = std::unique_ptr(new GrpcMetadataReader(writer)); + return GrpcStreamWriter::Open(descriptor, schema, std::move(rpc), std::move(response), + writer, out); } private: @@ -449,15 +599,16 @@ Status FlightClient::ListFlights(const FlightCallOptions& options, } Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket, - std::unique_ptr* stream) { + std::unique_ptr* stream) { return impl_->DoGet(options, ticket, stream); } Status FlightClient::DoPut(const FlightCallOptions& 
options, const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { - return impl_->DoPut(options, descriptor, schema, stream); + std::unique_ptr* stream, + std::unique_ptr* reader) { + return impl_->DoPut(options, descriptor, schema, stream, reader); } } // namespace flight diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index b8a5d4f4b91..0fa571d75fa 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/status.h" @@ -35,7 +36,6 @@ namespace arrow { class MemoryPool; class RecordBatch; -class RecordBatchReader; class Schema; namespace flight { @@ -66,6 +66,43 @@ class ARROW_FLIGHT_EXPORT FlightClientOptions { std::string override_hostname; }; +/// \brief A RecordBatchReader exposing Flight metadata and cancel +/// operations. +class ARROW_FLIGHT_EXPORT FlightStreamReader : public MetadataRecordBatchReader { + public: + /// \brief Try to cancel the call. + virtual void Cancel() = 0; +}; + +// Silence warning +// "non dll-interface class RecordBatchReader used as base for dll-interface class" +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4275) +#endif + +/// \brief A RecordBatchWriter that also allows sending +/// application-defined metadata via the Flight protocol. +class ARROW_FLIGHT_EXPORT FlightStreamWriter : public ipc::RecordBatchWriter { + public: + virtual Status WriteWithMetadata(const RecordBatch& batch, + std::shared_ptr app_metadata, + bool allow_64bit = false) = 0; +}; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/// \brief A reader for application-specific metadata sent back to the +/// client during an upload. +class ARROW_FLIGHT_EXPORT FlightMetadataReader { + public: + virtual ~FlightMetadataReader(); + /// \brief Read a message from the server. 
+ virtual Status ReadMetadata(std::shared_ptr* out) = 0; +}; + /// \brief Client class for Arrow Flight RPC services (gRPC-based). /// API experimental for now class ARROW_FLIGHT_EXPORT FlightClient { @@ -151,8 +188,8 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \param[out] stream the returned RecordBatchReader /// \return Status Status DoGet(const FlightCallOptions& options, const Ticket& ticket, - std::unique_ptr* stream); - Status DoGet(const Ticket& ticket, std::unique_ptr* stream) { + std::unique_ptr* stream); + Status DoGet(const Ticket& ticket, std::unique_ptr* stream) { return DoGet({}, ticket, stream); } @@ -163,13 +200,16 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \param[in] descriptor the descriptor of the stream /// \param[in] schema the schema for the data to upload /// \param[out] stream a writer to write record batches to + /// \param[out] reader a reader for application metadata from the server /// \return Status Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream); + std::unique_ptr* stream, + std::unique_ptr* reader); Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { - return DoPut({}, descriptor, schema, stream); + std::unique_ptr* stream, + std::unique_ptr* reader) { + return DoPut({}, descriptor, schema, stream, reader); } private: diff --git a/cpp/src/arrow/flight/flight-benchmark.cc b/cpp/src/arrow/flight/flight-benchmark.cc index f2bd356a647..f5dd462ad9a 100644 --- a/cpp/src/arrow/flight/flight-benchmark.cc +++ b/cpp/src/arrow/flight/flight-benchmark.cc @@ -106,10 +106,10 @@ Status RunPerformanceTest(const std::string& hostname, const int port) { perf::Token token; token.ParseFromString(endpoint.ticket.ticket); - std::unique_ptr reader; + std::unique_ptr reader; RETURN_NOT_OK(client->DoGet(endpoint.ticket, &reader)); - std::shared_ptr batch; + FlightStreamChunk batch; // This is hard-coded 
for right now, 4 columns each with int64 const int bytes_per_record = 32; @@ -120,26 +120,26 @@ Status RunPerformanceTest(const std::string& hostname, const int port) { int64_t num_bytes = 0; int64_t num_records = 0; while (true) { - RETURN_NOT_OK(reader->ReadNext(&batch)); - if (!batch) { + RETURN_NOT_OK(reader->Next(&batch)); + if (!batch.data) { break; } if (verify) { - auto values = - reinterpret_cast(batch->column_data(0)->buffers[1]->data()); + auto values = reinterpret_cast( + batch.data->column_data(0)->buffers[1]->data()); const int64_t start = token.start() + num_records; - for (int64_t i = 0; i < batch->num_rows(); ++i) { + for (int64_t i = 0; i < batch.data->num_rows(); ++i) { if (values[i] != start + i) { return Status::Invalid("verification failure"); } } } - num_records += batch->num_rows(); + num_records += batch.data->num_rows(); // Hard-coded - num_bytes += batch->num_rows() * bytes_per_record; + num_bytes += batch.data->num_rows() * bytes_per_record; } stats.Update(num_records, num_bytes); return Status::OK(); diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index 3c0b67cd992..c0f0c7f1e86 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -211,19 +211,19 @@ class TestFlightClient : public ::testing::Test { // By convention, fetch the first endpoint Ticket ticket = info->endpoints()[0].ticket; - std::unique_ptr stream; + std::unique_ptr stream; ASSERT_OK(client_->DoGet(ticket, &stream)); - std::shared_ptr chunk; + FlightStreamChunk chunk; for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->ReadNext(&chunk)); - ASSERT_NE(nullptr, chunk); - ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk); + ASSERT_OK(stream->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); } // Stream exhausted - ASSERT_OK(stream->ReadNext(&chunk)); - ASSERT_EQ(nullptr, chunk); + ASSERT_OK(stream->Next(&chunk)); + ASSERT_EQ(nullptr, 
chunk.data); } protected: @@ -255,7 +255,8 @@ class TlsTestServer : public FlightServerBase { class DoPutTestServer : public FlightServerBase { public: Status DoPut(const ServerCallContext& context, - std::unique_ptr reader) override { + std::unique_ptr reader, + std::unique_ptr writer) override { descriptor_ = reader->descriptor(); return reader->ReadAll(&batches_); } @@ -267,6 +268,70 @@ class DoPutTestServer : public FlightServerBase { friend class TestDoPut; }; +class MetadataTestServer : public FlightServerBase { + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override { + BatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + std::shared_ptr batch_reader = + std::make_shared(batches[0]->schema(), batches); + + *data_stream = std::unique_ptr(new NumberingStream( + std::unique_ptr(new RecordBatchStream(batch_reader)))); + return Status::OK(); + } + + Status DoPut(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override { + FlightStreamChunk chunk; + int counter = 0; + while (true) { + RETURN_NOT_OK(reader->Next(&chunk)); + if (chunk.data == nullptr) break; + if (chunk.app_metadata == nullptr) { + return Status::Invalid("Expected application metadata to be provided"); + } + if (std::to_string(counter) != chunk.app_metadata->ToString()) { + return Status::Invalid("Expected metadata value: " + std::to_string(counter) + + " but got: " + chunk.app_metadata->ToString()); + } + auto metadata = Buffer::FromString(std::to_string(counter)); + RETURN_NOT_OK(writer->WriteMetadata(*metadata)); + counter++; + } + return Status::OK(); + } +}; + +template +class InsecureTestServer : public ::testing::Test { + public: + void SetUp() { + Location location; + ASSERT_OK(Location::ForGrpcTcp("localhost", 30000, &location)); + + std::unique_ptr server(new T); + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); + + server_.reset(new 
InProcessTestServer(std::move(server), location)); + ASSERT_OK(server_->Start()); + ASSERT_OK(ConnectClient()); + } + + void TearDown() { server_->Stop(); } + + Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); } + + protected: + int port_; + std::unique_ptr client_; + std::unique_ptr server_; +}; + +using TestMetadata = InsecureTestServer; + class TestAuthHandler : public ::testing::Test { public: void SetUp() { @@ -323,8 +388,9 @@ class TestDoPut : public ::testing::Test { void CheckDoPut(FlightDescriptor descr, const std::shared_ptr& schema, const BatchVector& batches) { - std::unique_ptr stream; - ASSERT_OK(client_->DoPut(descr, schema, &stream)); + std::unique_ptr stream; + std::unique_ptr reader; + ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader)); for (const auto& batch : batches) { ASSERT_OK(stream->WriteRecordBatch(*batch)); } @@ -485,7 +551,7 @@ TEST_F(TestFlightClient, Issue5095) { // Make sure the server-side error message is reflected to the // client Ticket ticket1{"ARROW-5095-fail"}; - std::unique_ptr stream; + std::unique_ptr stream; Status status = client_->DoGet(ticket1, &stream); ASSERT_RAISES(IOError, status); ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error")); @@ -588,13 +654,14 @@ TEST_F(TestAuthHandler, PassAuthenticatedCalls) { status = client_->GetFlightInfo(FlightDescriptor{}, &info); ASSERT_RAISES(NotImplemented, status); - std::unique_ptr stream; + std::unique_ptr stream; status = client_->DoGet(Ticket{}, &stream); ASSERT_RAISES(NotImplemented, status); - std::unique_ptr writer; + std::unique_ptr writer; + std::unique_ptr reader; std::shared_ptr schema = arrow::schema({}); - status = client_->DoPut(FlightDescriptor{}, schema, &writer); + status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader); ASSERT_OK(status); status = writer->Close(); ASSERT_RAISES(NotImplemented, status); @@ -625,15 +692,16 @@ TEST_F(TestAuthHandler, FailUnauthenticatedCalls) { 
ASSERT_RAISES(IOError, status); ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token")); - std::unique_ptr stream; + std::unique_ptr stream; status = client_->DoGet(Ticket{}, &stream); ASSERT_RAISES(IOError, status); ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token")); - std::unique_ptr writer; + std::unique_ptr writer; + std::unique_ptr reader; std::shared_ptr schema( (new arrow::Schema(std::vector>()))); - status = client_->DoPut(FlightDescriptor{}, schema, &writer); + status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader); ASSERT_OK(status); status = writer->Close(); ASSERT_RAISES(IOError, status); @@ -693,5 +761,72 @@ TEST_F(TestTls, OverrideHostname) { ASSERT_RAISES(IOError, client->DoAction(options, action, &results)); } +TEST_F(TestMetadata, DoGet) { + Ticket ticket{""}; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(ticket, &stream)); + + BatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + FlightStreamChunk chunk; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(stream->Next(&chunk)); + ASSERT_NE(nullptr, chunk.data); + ASSERT_NE(nullptr, chunk.app_metadata); + ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data); + ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString()); + } + ASSERT_OK(stream->Next(&chunk)); + ASSERT_EQ(nullptr, chunk.data); +} + +TEST_F(TestMetadata, DoPut) { + std::unique_ptr writer; + std::unique_ptr reader; + std::shared_ptr schema = ExampleIntSchema(); + ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader)); + + BatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + std::shared_ptr chunk; + std::shared_ptr metadata; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], + Buffer::FromString(std::to_string(i)))); + } + // This 
eventually calls grpc::ClientReaderWriter::Finish which can + // hang if there are unread messages. So make sure our wrapper + // around this doesn't hang (because it drains any unread messages) + ASSERT_OK(writer->Close()); +} + +TEST_F(TestMetadata, DoPutReadMetadata) { + std::unique_ptr writer; + std::unique_ptr reader; + std::shared_ptr schema = ExampleIntSchema(); + ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader)); + + BatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); + + std::shared_ptr chunk; + std::shared_ptr metadata; + auto num_batches = static_cast(expected_batches.size()); + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i], + Buffer::FromString(std::to_string(i)))); + ASSERT_OK(reader->ReadMetadata(&metadata)); + ASSERT_NE(nullptr, metadata); + ASSERT_EQ(std::to_string(i), metadata->ToString()); + } + // As opposed to DoPutDrainMetadata, now we've read the messages, so + // make sure this still closes as expected. + ASSERT_OK(writer->Close()); +} + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index 784e8ebae1c..5283bed2183 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -63,7 +63,8 @@ namespace flight { namespace internal { -static const char* AUTH_HEADER = "auth-token-bin"; +/// The name of the header used to pass authentication tokens. 
+static const char* kGrpcAuthHeader = "auth-token-bin"; ARROW_FLIGHT_EXPORT Status SchemaToString(const Schema& schema, std::string* out); diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc index d78bac83870..0ff5abad692 100644 --- a/cpp/src/arrow/flight/serialization-internal.cc +++ b/cpp/src/arrow/flight/serialization-internal.cc @@ -163,6 +163,14 @@ grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, // 1 byte for metadata tag header_size += 1 + WireFormatLite::LengthDelimitedSize(metadata_size); + // App metadata tag if appropriate + int32_t app_metadata_size = 0; + if (msg.app_metadata && msg.app_metadata->size() > 0) { + DCHECK_LT(msg.app_metadata->size(), kInt32Max); + app_metadata_size = static_cast(msg.app_metadata->size()); + header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size); + } + for (const auto& buffer : ipc_msg.body_buffers) { // Buffer may be null when the row length is zero, or when all // entries are invalid. 
@@ -214,6 +222,15 @@ grpc::Status FlightDataSerialize(const FlightPayload& msg, ByteBuffer* out, header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(), static_cast(ipc_msg.metadata->size())); + // Write app metadata + if (app_metadata_size > 0) { + WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &header_stream); + header_stream.WriteVarint32(app_metadata_size); + header_stream.WriteRawMaybeAliased(msg.app_metadata->data(), + static_cast(msg.app_metadata->size())); + } + if (has_body) { // Write body tag WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber, @@ -292,6 +309,12 @@ grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { "Unable to read FlightData metadata"); } } break; + case pb::FlightData::kAppMetadataFieldNumber: { + if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->app_metadata)) { + return grpc::Status(grpc::StatusCode::INTERNAL, + "Unable to read FlightData application metadata"); + } + } break; case pb::FlightData::kDataBodyFieldNumber: { if (!ReadBytesZeroCopy(wrapped_buffer, &pb_stream, &out->body)) { return grpc::Status(grpc::StatusCode::INTERNAL, @@ -330,7 +353,7 @@ Status FlightData::OpenMessage(std::unique_ptr* message) { // (see customize_protobuf.h). 
bool WritePayload(const FlightPayload& payload, - grpc::ClientWriter* writer) { + grpc::ClientReaderWriter* writer) { // Pretend to be pb::FlightData and intercept in SerializationTraits return writer->Write(*reinterpret_cast(&payload), grpc::WriteOptions()); @@ -348,7 +371,8 @@ bool ReadPayload(grpc::ClientReader* reader, FlightData* data) { return reader->Read(reinterpret_cast(data)); } -bool ReadPayload(grpc::ServerReader* reader, FlightData* data) { +bool ReadPayload(grpc::ServerReaderWriter* reader, + FlightData* data) { // Pretend to be pb::FlightData and intercept in SerializationTraits return reader->Read(reinterpret_cast(data)); } diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index aa47af6ae35..cfb8b8ab048 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -43,6 +43,9 @@ struct FlightData { /// Non-length-prefixed Message header as described in format/Message.fbs std::shared_ptr metadata; + /// Application-defined metadata + std::shared_ptr app_metadata; + /// Message body std::shared_ptr body; @@ -53,14 +56,15 @@ struct FlightData { /// Write Flight message on gRPC stream with zero-copy optimizations. /// True is returned on success, false if some error occurred (connection closed?). bool WritePayload(const FlightPayload& payload, - grpc::ClientWriter* writer); + grpc::ClientReaderWriter* writer); bool WritePayload(const FlightPayload& payload, grpc::ServerWriter* writer); /// Read Flight message from gRPC stream with zero-copy optimizations. /// True is returned on success, false if stream ended. 
bool ReadPayload(grpc::ClientReader* reader, FlightData* data); -bool ReadPayload(grpc::ServerReader* reader, FlightData* data); +bool ReadPayload(grpc::ServerReaderWriter* reader, + FlightData* data); } // namespace internal } // namespace flight diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 6f3c466c4ad..d059a8be923 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -72,12 +72,15 @@ namespace { // A MessageReader implementation that reads from a gRPC ServerReader class FlightIpcMessageReader : public ipc::MessageReader { public: - explicit FlightIpcMessageReader(grpc::ServerReader* reader) - : reader_(reader) {} + explicit FlightIpcMessageReader( + grpc::ServerReaderWriter* reader, + std::shared_ptr* last_metadata) + : reader_(reader), app_metadata_(last_metadata) {} Status ReadNextMessage(std::unique_ptr* out) override { if (stream_finished_) { *out = nullptr; + *app_metadata_ = nullptr; return Status::OK(); } internal::FlightData data; @@ -89,6 +92,7 @@ class FlightIpcMessageReader : public ipc::MessageReader { "Client provided malformed message or did not provide message"); } *out = nullptr; + *app_metadata_ = nullptr; return Status::OK(); } @@ -100,25 +104,29 @@ class FlightIpcMessageReader : public ipc::MessageReader { first_message_ = false; } - return data.OpenMessage(out); + RETURN_NOT_OK(data.OpenMessage(out)); + *app_metadata_ = std::move(data.app_metadata); + return Status::OK(); } const FlightDescriptor& descriptor() const { return descriptor_; } protected: - grpc::ServerReader* reader_; + grpc::ServerReaderWriter* reader_; bool stream_finished_ = false; bool first_message_ = true; FlightDescriptor descriptor_; + std::shared_ptr* app_metadata_; }; class FlightMessageReaderImpl : public FlightMessageReader { public: - explicit FlightMessageReaderImpl(grpc::ServerReader* reader) + explicit FlightMessageReaderImpl( + grpc::ServerReaderWriter* reader) : reader_(reader) {} Status Init() { 
- message_reader_ = new FlightIpcMessageReader(reader_); + message_reader_ = new FlightIpcMessageReader(reader_, &last_metadata_); return ipc::RecordBatchStreamReader::Open( std::unique_ptr(message_reader_), &batch_reader_); } @@ -129,18 +137,41 @@ class FlightMessageReaderImpl : public FlightMessageReader { std::shared_ptr schema() const override { return batch_reader_->schema(); } - Status ReadNext(std::shared_ptr* out) override { - return batch_reader_->ReadNext(out); + Status Next(FlightStreamChunk* out) override { + out->app_metadata = nullptr; + RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); + out->app_metadata = std::move(last_metadata_); + return Status::OK(); } private: std::shared_ptr schema_; std::unique_ptr dictionary_memo_; - grpc::ServerReader* reader_; + grpc::ServerReaderWriter* reader_; FlightIpcMessageReader* message_reader_; + std::shared_ptr last_metadata_; std::shared_ptr batch_reader_; }; +class GrpcMetadataWriter : public FlightMetadataWriter { + public: + explicit GrpcMetadataWriter( + grpc::ServerReaderWriter* writer) + : writer_(writer) {} + + Status WriteMetadata(const Buffer& buffer) override { + pb::PutResult message{}; + message.set_app_metadata(buffer.data(), buffer.size()); + if (writer_->Write(message)) { + return Status::OK(); + } + return Status::IOError("Unknown error writing metadata."); + } + + private: + grpc::ServerReaderWriter* writer_; +}; + class GrpcServerAuthReader : public ServerAuthReader { public: explicit GrpcServerAuthReader( @@ -153,7 +184,7 @@ class GrpcServerAuthReader : public ServerAuthReader { *token = std::move(*request.release_payload()); return Status::OK(); } - return Status::UnknownError("Could not read client handshake request."); + return Status::IOError("Stream is closed."); } private: @@ -246,7 +277,7 @@ class FlightServiceImpl : public FlightService::Service { } const auto client_metadata = context->client_metadata(); - const auto auth_header = client_metadata.find(internal::AUTH_HEADER); + 
const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader); std::string token; if (auth_header == client_metadata.end()) { token = ""; @@ -349,16 +380,18 @@ class FlightServiceImpl : public FlightService::Service { return grpc::Status::OK; } - grpc::Status DoPut(ServerContext* context, grpc::ServerReader* reader, - pb::PutResult* response) { + grpc::Status DoPut(ServerContext* context, + grpc::ServerReaderWriter* reader) { GrpcServerCallContext flight_context; GRPC_RETURN_NOT_GRPC_OK(CheckAuth(context, flight_context)); auto message_reader = std::unique_ptr(new FlightMessageReaderImpl(reader)); GRPC_RETURN_NOT_OK(message_reader->Init()); - return internal::ToGrpcStatus( - server_->DoPut(flight_context, std::move(message_reader))); + auto metadata_writer = + std::unique_ptr(new GrpcMetadataWriter(reader)); + return internal::ToGrpcStatus(server_->DoPut( + flight_context, std::move(message_reader), std::move(metadata_writer))); } grpc::Status ListActions(ServerContext* context, const pb::Empty* request, @@ -410,6 +443,8 @@ class FlightServiceImpl : public FlightService::Service { } // namespace +FlightMetadataWriter::~FlightMetadataWriter() = default; + // // gRPC server lifecycle // @@ -572,7 +607,8 @@ Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& r } Status FlightServerBase::DoPut(const ServerCallContext& context, - std::unique_ptr reader) { + std::unique_ptr reader, + std::unique_ptr writer) { return Status::NotImplemented("NYI"); } diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index c1bcb5c0a3d..25656e625e1 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -74,23 +74,22 @@ class ARROW_FLIGHT_EXPORT RecordBatchStream : public FlightDataStream { std::unique_ptr impl_; }; -// Silence warning -// "non dll-interface class RecordBatchReader used as base for dll-interface class" -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4275) -#endif - 
-/// \brief A reader for IPC payloads uploaded by a client -class ARROW_FLIGHT_EXPORT FlightMessageReader : public RecordBatchReader { +/// \brief A reader for IPC payloads uploaded by a client. Also allows +/// reading application-defined metadata via the Flight protocol. +class ARROW_FLIGHT_EXPORT FlightMessageReader : public MetadataRecordBatchReader { public: /// \brief Get the descriptor for this upload. virtual const FlightDescriptor& descriptor() const = 0; }; -#ifdef _MSC_VER -#pragma warning(pop) -#endif +/// \brief A writer for application-specific metadata sent back to the +/// client during an upload. +class ARROW_FLIGHT_EXPORT FlightMetadataWriter { + public: + virtual ~FlightMetadataWriter(); + /// \brief Send a message to the client. + virtual Status WriteMetadata(const Buffer& app_metadata) = 0; +}; /// \brief Call state/contextual data. class ARROW_FLIGHT_EXPORT ServerCallContext { @@ -178,9 +177,11 @@ class ARROW_FLIGHT_EXPORT FlightServerBase { /// \brief Process a stream of IPC payloads sent from a client /// \param[in] context The call context. /// \param[in] reader a sequence of uploaded record batches + /// \param[in] writer send metadata back to the client /// \return Status virtual Status DoPut(const ServerCallContext& context, - std::unique_ptr reader); + std::unique_ptr reader, + std::unique_ptr writer); /// \brief Execute an action, return stream of zero or more results /// \param[in] context The call context. diff --git a/cpp/src/arrow/flight/test-integration-client.cc b/cpp/src/arrow/flight/test-integration-client.cc index abaa3bc4221..b02595eaec9 100644 --- a/cpp/src/arrow/flight/test-integration-client.cc +++ b/cpp/src/arrow/flight/test-integration-client.cc @@ -44,17 +44,26 @@ DEFINE_string(host, "localhost", "Server port to connect to"); DEFINE_int32(port, 31337, "Server port to connect to"); DEFINE_string(path, "", "Resource path to request"); -/// \brief Helper to read a RecordBatchReader into a Table. 
-arrow::Status ReadToTable(std::unique_ptr& reader, +/// \brief Helper to read a MetadataRecordBatchReader into a Table. +arrow::Status ReadToTable(arrow::flight::MetadataRecordBatchReader& reader, std::shared_ptr* retrieved_data) { + // For integration testing, we expect the server to number the + // batches, to test the application metadata part of the spec. std::vector> retrieved_chunks; - std::shared_ptr chunk; + arrow::flight::FlightStreamChunk chunk; + int counter = 0; while (true) { - RETURN_NOT_OK(reader->ReadNext(&chunk)); - if (chunk == nullptr) break; - retrieved_chunks.push_back(chunk); + RETURN_NOT_OK(reader.Next(&chunk)); + if (!chunk.data) break; + retrieved_chunks.push_back(chunk.data); + if (std::to_string(counter) != chunk.app_metadata->ToString()) { + return arrow::Status::Invalid( + "Expected metadata value: " + std::to_string(counter) + + " but got: " + chunk.app_metadata->ToString()); + } + counter++; } - return arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, + return arrow::Table::FromRecordBatches(reader.schema(), retrieved_chunks, retrieved_data); } @@ -71,14 +80,27 @@ arrow::Status ReadToTable(std::unique_ptr& reader, - arrow::ipc::RecordBatchWriter& writer) { +/// \brief Upload the contents of a RecordBatchReader to a Flight +/// server, validating the application metadata on the side. 
+arrow::Status UploadReaderToFlight(arrow::RecordBatchReader* reader, + arrow::flight::FlightStreamWriter& writer, + arrow::flight::FlightMetadataReader& metadata_reader) { + int counter = 0; while (true) { std::shared_ptr chunk; RETURN_NOT_OK(reader->ReadNext(&chunk)); if (chunk == nullptr) break; - RETURN_NOT_OK(writer.WriteRecordBatch(*chunk)); + std::shared_ptr metadata = + arrow::Buffer::FromString(std::to_string(counter)); + RETURN_NOT_OK(writer.WriteWithMetadata(*chunk, metadata)); + // Wait for the server to ack the result + std::shared_ptr ack_metadata; + RETURN_NOT_OK(metadata_reader.ReadMetadata(&ack_metadata)); + if (!ack_metadata->Equals(*metadata)) { + return arrow::Status::Invalid("Expected metadata value: " + metadata->ToString() + + " but got: " + ack_metadata->ToString()); + } + counter++; } return writer.Close(); } @@ -91,10 +113,10 @@ arrow::Status ConsumeFlightLocation(const arrow::flight::Location& location, std::unique_ptr read_client; RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &read_client)); - std::unique_ptr stream; + std::unique_ptr stream; RETURN_NOT_OK(read_client->DoGet(ticket, &stream)); - return ReadToTable(stream, retrieved_data); + return ReadToTable(*stream, retrieved_data); } int main(int argc, char** argv) { @@ -120,11 +142,12 @@ int main(int argc, char** argv) { std::shared_ptr original_data; ABORT_NOT_OK(ReadToTable(reader, &original_data)); - std::unique_ptr write_stream; - ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream)); + std::unique_ptr write_stream; + std::unique_ptr metadata_reader; + ABORT_NOT_OK(client->DoPut(descr, reader->schema(), &write_stream, &metadata_reader)); std::unique_ptr table_reader( new arrow::TableBatchReader(*original_data)); - ABORT_NOT_OK(CopyReaderToWriter(table_reader, *write_stream)); + ABORT_NOT_OK(UploadReaderToFlight(table_reader.get(), *write_stream, *metadata_reader)); // 2. Get the ticket for the data. 
std::unique_ptr info; diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index c5bb180663b..fe6b53dbbcd 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -79,14 +79,16 @@ class FlightIntegrationTestServer : public FlightServerBase { } auto flight = data->second; - *data_stream = std::unique_ptr(new RecordBatchStream( - std::shared_ptr(new TableBatchReader(*flight)))); + *data_stream = std::unique_ptr( + new NumberingStream(std::unique_ptr(new RecordBatchStream( + std::shared_ptr(new TableBatchReader(*flight)))))); return Status::OK(); } Status DoPut(const ServerCallContext& context, - std::unique_ptr reader) override { + std::unique_ptr reader, + std::unique_ptr writer) override { const FlightDescriptor& descriptor = reader->descriptor(); if (descriptor.type != FlightDescriptor::DescriptorType::PATH) { @@ -98,11 +100,14 @@ class FlightIntegrationTestServer : public FlightServerBase { std::string key = descriptor.path[0]; std::vector> retrieved_chunks; - std::shared_ptr chunk; + arrow::flight::FlightStreamChunk chunk; while (true) { - RETURN_NOT_OK(reader->ReadNext(&chunk)); - if (chunk == nullptr) break; - retrieved_chunks.push_back(chunk); + RETURN_NOT_OK(reader->Next(&chunk)); + if (chunk.data == nullptr) break; + retrieved_chunks.push_back(chunk.data); + if (chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata)); + } } std::shared_ptr retrieved_data; RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc index 7dd78fdd6eb..4408801a97e 100644 --- a/cpp/src/arrow/flight/test-util.cc +++ b/cpp/src/arrow/flight/test-util.cc @@ -260,6 +260,24 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, return internal::SchemaToString(schema, &out->schema); } 
+NumberingStream::NumberingStream(std::unique_ptr stream) + : counter_(0), stream_(std::move(stream)) {} + +std::shared_ptr NumberingStream::schema() { return stream_->schema(); } + +Status NumberingStream::GetSchemaPayload(FlightPayload* payload) { + return stream_->GetSchemaPayload(payload); +} + +Status NumberingStream::Next(FlightPayload* payload) { + RETURN_NOT_OK(stream_->Next(payload)); + if (payload && payload->ipc_message.type == ipc::Message::RECORD_BATCH) { + payload->app_metadata = Buffer::FromString(std::to_string(counter_)); + counter_++; + } + return Status::OK(); +} + std::shared_ptr ExampleIntSchema() { auto f0 = field("f0", int32()); auto f1 = field("f1", int32()); diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h index 5b02630b432..7fb0b605e7a 100644 --- a/cpp/src/arrow/flight/test-util.h +++ b/cpp/src/arrow/flight/test-util.h @@ -25,6 +25,7 @@ #include "arrow/status.h" #include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" #include "arrow/flight/server_auth.h" #include "arrow/flight/types.h" #include "arrow/flight/visibility.h" @@ -127,6 +128,23 @@ class ARROW_FLIGHT_EXPORT BatchIterator : public RecordBatchReader { #pragma warning(pop) #endif +// ---------------------------------------------------------------------- +// A FlightDataStream that numbers the record batches +/// \brief A FlightDataStream that wraps another stream and attaches a +/// running counter as application metadata to each record batch payload +class ARROW_FLIGHT_EXPORT NumberingStream : public FlightDataStream { + public: + explicit NumberingStream(std::unique_ptr stream); + + std::shared_ptr schema() override; + Status GetSchemaPayload(FlightPayload* payload) override; + Status Next(FlightPayload* payload) override; + + private: + int counter_; + std::shared_ptr stream_; +}; + // ---------------------------------------------------------------------- // Example data for test-server and unit tests diff --git 
a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index d982efce5ce..c82e6813648 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -25,6 +25,7 @@ #include "arrow/ipc/dictionary.h" #include "arrow/ipc/reader.h" #include "arrow/status.h" +#include "arrow/table.h" #include "arrow/util/uri.h" namespace arrow { @@ -122,6 +123,24 @@ bool Location::Equals(const Location& other) const { return ToString() == other.ToString(); } +Status MetadataRecordBatchReader::ReadAll( + std::vector>* batches) { + FlightStreamChunk chunk; + + while (true) { + RETURN_NOT_OK(Next(&chunk)); + if (!chunk.data) break; + batches->emplace_back(std::move(chunk.data)); + } + return Status::OK(); +} + +Status MetadataRecordBatchReader::ReadAll(std::shared_ptr* table) { + std::vector> batches; + RETURN_NOT_OK(ReadAll(&batches)); + return Table::FromRecordBatches(schema(), batches, table); +} + SimpleFlightListing::SimpleFlightListing(const std::vector& flights) : position_(0), flights_(flights) {} diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index e5f7bcdd550..abf894c88c8 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -32,8 +32,10 @@ namespace arrow { class Buffer; +class RecordBatch; class Schema; class Status; +class Table; namespace ipc { @@ -205,6 +207,7 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint { /// This structure corresponds to FlightData in the protocol. struct ARROW_FLIGHT_EXPORT FlightPayload { std::shared_ptr descriptor; + std::shared_ptr app_metadata; ipc::internal::IpcPayload ipc_message; }; @@ -278,6 +281,30 @@ class ARROW_FLIGHT_EXPORT ResultStream { virtual Status Next(std::unique_ptr* info) = 0; }; +/// \brief A holder for a RecordBatch with associated Flight metadata. +struct ARROW_FLIGHT_EXPORT FlightStreamChunk { + public: + std::shared_ptr data; + std::shared_ptr app_metadata; +}; + +/// \brief An interface to read Flight data with metadata. 
+class ARROW_FLIGHT_EXPORT MetadataRecordBatchReader { + public: + virtual ~MetadataRecordBatchReader() = default; + + /// \brief Get the schema for this stream. + virtual std::shared_ptr schema() const = 0; + /// \brief Get the next message from Flight. If the stream is + /// finished, then the members of \a FlightStreamChunk will be + /// nullptr. + virtual Status Next(FlightStreamChunk* next) = 0; + /// \brief Consume entire stream as a vector of record batches + virtual Status ReadAll(std::vector>* batches); + /// \brief Consume entire stream as a Table + virtual Status ReadAll(std::shared_ptr
* table); +}; + // \brief Create a FlightListing from a vector of FlightInfo objects. This can // be iterated once, then it is consumed class ARROW_FLIGHT_EXPORT SimpleFlightListing : public FlightListing { diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 717f29aaf40..992c4fc138f 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -455,7 +455,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { RETURN_NOT_OK( ReadMessageAndValidate(message_reader_.get(), /*allow_null=*/false, &message)); - CHECK_MESSAGE_TYPE(message->type(), Message::SCHEMA); + CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type()); CHECK_HAS_NO_BODY(*message); if (message->header() == nullptr) { return Status::IOError("Header-pointer of flatbuffer-encoded Message is null."); diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index ee19fb9d1e8..da5026a7f9c 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -108,10 +108,12 @@ Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context, }); } -Status PyFlightServer::DoPut(const arrow::flight::ServerCallContext& context, - std::unique_ptr reader) { +Status PyFlightServer::DoPut( + const arrow::flight::ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) { return SafeCallIntoPython([&] { - vtable_.do_put(server_.obj(), context, std::move(reader)); + vtable_.do_put(server_.obj(), context, std::move(reader), std::move(writer)); return CheckPyError(); }); } diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index aecb97a10f5..5aea7e85259 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -50,7 +50,8 @@ class ARROW_PYTHON_EXPORT PyFlightServerVtable { std::unique_ptr*)> do_get; std::function)> + std::unique_ptr, + std::unique_ptr)> do_put; std::function* stream) override; Status DoPut(const arrow::flight::ServerCallContext& context, - 
std::unique_ptr reader) override; + std::unique_ptr reader, + std::unique_ptr writer) override; Status DoAction(const arrow::flight::ServerCallContext& context, const arrow::flight::Action& action, std::unique_ptr* result) override; diff --git a/docs/source/conf.py b/docs/source/conf.py index d525fa94313..e605125a852 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -416,7 +416,16 @@ from unittest import mock pyarrow.cuda = sys.modules['pyarrow.cuda'] = mock.Mock() +try: + import pyarrow.flight + flight_enabled = True +except ImportError: + flight_enabled = False + pyarrow.flight = sys.modules['pyarrow.flight'] = mock.Mock() + + def setup(app): # Use a config value to indicate whether CUDA API docs can be generated. # This will also rebuild appropriately when the value changes. app.add_config_value('cuda_enabled', cuda_enabled, 'env') + app.add_config_value('flight_enabled', flight_enabled, 'env') diff --git a/docs/source/cpp/api.rst b/docs/source/cpp/api.rst index 522609e85aa..1c113b7de68 100644 --- a/docs/source/cpp/api.rst +++ b/docs/source/cpp/api.rst @@ -30,3 +30,4 @@ API Reference api/table api/utilities api/cuda + api/flight diff --git a/docs/source/cpp/api/flight.rst b/docs/source/cpp/api/flight.rst new file mode 100644 index 00000000000..4e56a7690ac --- /dev/null +++ b/docs/source/cpp/api/flight.rst @@ -0,0 +1,126 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. 
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +================ +Arrow Flight RPC +================ + +.. warning:: Flight is currently unstable. APIs are subject to change, + though we don't expect drastic changes. + +.. warning:: Flight is currently only available when built from source + appropriately. + +Common Types +============ + +.. doxygenstruct:: arrow::flight::Action + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::ActionType + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::Criteria + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::FlightDescriptor + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::FlightEndpoint + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::FlightInfo + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::FlightPayload + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::FlightListing + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::Location + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::PutResult + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::Result + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::ResultStream + :project: arrow_cpp + :members: + +.. doxygenstruct:: arrow::flight::Ticket + :project: arrow_cpp + :members: + +Clients +======= + +.. doxygenclass:: arrow::flight::FlightClient + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::FlightCallOptions + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::ClientAuthHandler + :project: arrow_cpp + :members: + +.. doxygentypedef:: arrow::flight::TimeoutDuration + :project: arrow_cpp + +Servers +======= + +.. 
doxygenclass:: arrow::flight::FlightServerBase + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::FlightDataStream + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::FlightMessageReader + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::RecordBatchStream + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::ServerAuthHandler + :project: arrow_cpp + :members: + +.. doxygenclass:: arrow::flight::ServerCallContext + :project: arrow_cpp + :members: diff --git a/docs/source/format/Flight.rst b/docs/source/format/Flight.rst new file mode 100644 index 00000000000..b3476eadf33 --- /dev/null +++ b/docs/source/format/Flight.rst @@ -0,0 +1,106 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Arrow Flight RPC +================ + +Arrow Flight is an RPC framework for high-performance data services +based on Arrow data, and is built on top of gRPC_ and the :doc:`IPC +format `. + +Flight is organized around streams of Arrow record batches, being +either downloaded from or uploaded to another service. A set of +metadata methods offers discovery and introspection of streams, as +well as the ability to implement application-specific methods. 
+ +Methods and message wire formats are defined by Protobuf, enabling +interoperability with clients that may support gRPC and Arrow +separately, but not Flight. However, Flight implementations include +further optimizations to avoid overhead in usage of Protobuf (mostly +around avoiding excessive memory copies). + +.. _gRPC: https://grpc.io/ + +RPC Methods +----------- + +Flight defines a set of RPC methods for uploading/downloading data, +retrieving metadata about a data stream, listing available data +streams, and for implementing application-specific RPC methods. A +Flight service implements some subset of these methods, while a Flight +client can call any of these methods. Thus, one Flight client can +connect to any Flight service and perform basic operations. + +Data streams are identified by descriptors, which are either a path or +an arbitrary binary command. A client that wishes to download the data +would: + +#. Construct or acquire a ``FlightDescriptor`` for the data set they + are interested in. A client may know what descriptor they want + already, or they may use methods like ``ListFlights`` to discover + them. +#. Call ``GetFlightInfo(FlightDescriptor)`` to get a ``FlightInfo`` + message containing details on where the data is located (as well as + other metadata, like the schema and possibly an estimate of the + dataset size). + + Flight does not require that data live on the same server as + metadata: this call may list other servers to connect to. The + ``FlightInfo`` message includes a ``Ticket``, an opaque binary + token that the server uses to identify the exact data set being + requested. +#. Connect to other servers (if needed). +#. Call ``DoGet(Ticket)`` to get back a stream of Arrow record + batches. + +To upload data, a client would: + +#. Construct or acquire a ``FlightDescriptor``, as before. +#. Call ``DoPut(FlightData)`` and upload a stream of Arrow record + batches. 
They would also include the ``FlightDescriptor`` with the + first message. + +See `Protocol Buffer Definitions`_ for full details on the methods and +messages involved. + +Authentication +~~~~~~~~~~~~~~ + +Flight supports application-implemented authentication +methods. Authentication, if enabled, has two phases: at connection +time, the client and server can exchange any number of messages. Then, +the client can provide a token alongside each call, and the server can +validate that token. + +Applications may use any part of this; for instance, they may ignore +the initial handshake and send an externally acquired token on each +call, or they may establish trust during the handshake and not +validate a token for each call. (Note that the latter is not secure if +you choose to deploy a layer 7 load balancer, as is common with gRPC.) + +External Resources +------------------ + +- https://arrow.apache.org/blog/2018/10/09/0.11.0-release/ +- https://www.slideshare.net/JacquesNadeau5/apache-arrow-flight-overview + +Protocol Buffer Definitions +--------------------------- + +.. literalinclude:: ../../../format/Flight.proto + :language: protobuf + :linenos: diff --git a/docs/source/index.rst b/docs/source/index.rst index 3b639b4d2ba..6fb16f25180 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -43,6 +43,7 @@ such topics as: format/Layout format/Metadata format/IPC + format/Flight .. _toc.usage: diff --git a/docs/source/python/api.rst b/docs/source/python/api.rst index b06509f7a5b..b1dccd4aad6 100644 --- a/docs/source/python/api.rst +++ b/docs/source/python/api.rst @@ -30,6 +30,7 @@ API Reference api/files api/tables api/ipc + api/flight api/formats api/plasma api/cuda diff --git a/docs/source/python/api/flight.rst b/docs/source/python/api/flight.rst new file mode 100644 index 00000000000..4fa137439ba --- /dev/null +++ b/docs/source/python/api/flight.rst @@ -0,0 +1,82 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. 
or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. currentmodule:: pyarrow.flight + +Arrow Flight +============ + +.. ifconfig:: not flight_enabled + + .. error:: + This documentation was built without Flight enabled. The Flight + API docs are not available. + +.. NOTE We still generate those API docs (with empty docstrings) +.. when Flight is disabled and `pyarrow.flight` mocked (see conf.py). +.. Otherwise we'd get autodoc warnings, see https://github.com/sphinx-doc/sphinx/issues/4770 + +.. warning:: Flight is currently unstable. APIs are subject to change, + though we don't expect drastic changes. + +.. warning:: Flight is currently not distributed as part of wheels or + in Conda - it is only available when built from source + appropriately. + +Common Types +------------ + +.. autosummary:: + :toctree: ../generated/ + + Action + ActionType + DescriptorType + FlightDescriptor + FlightEndpoint + FlightInfo + Location + Ticket + Result + +Flight Client +------------- + +.. autosummary:: + :toctree: ../generated/ + + FlightCallOptions + FlightClient + +Flight Server +------------- + +.. autosummary:: + :toctree: ../generated/ + + FlightServerBase + GeneratorStream + RecordBatchStream + +Authentication +-------------- + +.. 
autosummary:: + :toctree: ../generated/ + + ClientAuthHandler + ServerAuthHandler diff --git a/format/Flight.proto b/format/Flight.proto index 7f0488b86c3..0c8f28e5315 100644 --- a/format/Flight.proto +++ b/format/Flight.proto @@ -77,7 +77,7 @@ service FlightService { * number. In the latter, the service might implement a 'seal' action that * can be applied to a descriptor once all streams are uploaded. */ - rpc DoPut(stream FlightData) returns (PutResult) {} + rpc DoPut(stream FlightData) returns (stream PutResult) {} /* * Flight services can support an arbitrary number of simple actions in @@ -285,6 +285,11 @@ message FlightData { */ bytes data_header = 2; + /* + * Application-defined metadata. + */ + bytes app_metadata = 3; + /* * The actual batch of Arrow data. Preferably handled with minimal-copies * coming last in the definition to help with sidecar patterns (it is @@ -295,7 +300,8 @@ message FlightData { } /** - * The response message (currently empty) associated with the submission of a - * DoPut. + * The response message associated with the submission of a DoPut. 
*/ -message PutResult {} +message PutResult { + bytes app_metadata = 1; +} diff --git a/integration/integration_test.py b/integration/integration_test.py index a4763c98c8c..aca05747c72 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1152,7 +1152,7 @@ def _temp_path(): generate_interval_case(), generate_map_case(), generate_nested_case(), - generate_dictionary_case().skip_category(SKIP_FLIGHT), + generate_dictionary_case(), generate_nested_dictionary_case().skip_category(SKIP_ARROW) .skip_category(SKIP_FLIGHT), ] diff --git a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 550f5c113cf..787906950d5 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -35,6 +35,7 @@ import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -52,7 +53,6 @@ import io.grpc.Drainable; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; -import io.grpc.internal.ReadableBuffer; import io.grpc.protobuf.ProtoUtils; import io.netty.buffer.ArrowBuf; @@ -74,9 +74,16 @@ class ArrowMessage implements AutoCloseable { (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; private static final int HEADER_TAG = (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int APP_METADATA_TAG = + (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; private static Marshaller NO_BODY_MARSHALLER = 
ProtoUtils.marshaller(FlightData.getDefaultInstance()); + /** Get the application-specific metadata in this message. The ArrowMessage retains ownership of the buffer. */ + public ArrowBuf getApplicationMetadata() { + return appMetadata; + } + /** Types of messages that can be sent. */ public enum HeaderType { NONE, @@ -114,6 +121,7 @@ public static HeaderType getHeader(byte b) { private final FlightDescriptor descriptor; private final Message message; + private final ArrowBuf appMetadata; private final List bufs; public ArrowMessage(FlightDescriptor descriptor, Schema schema) { @@ -124,9 +132,15 @@ public ArrowMessage(FlightDescriptor descriptor, Schema schema) { message = Message.getRootAsMessage(serializedMessage); bufs = ImmutableList.of(); this.descriptor = descriptor; + this.appMetadata = null; } - public ArrowMessage(ArrowRecordBatch batch) { + /** + * Create an ArrowMessage from a record batch and app metadata. + * @param batch The record batch. + * @param appMetadata The app metadata. May be null. Takes ownership of the buffer otherwise. 
+ */ + public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata) { FlatBufferBuilder builder = new FlatBufferBuilder(); int batchOffset = batch.writeTo(builder); ByteBuffer serializedMessage = MessageSerializer.serializeMessage(builder, MessageHeader.RecordBatch, batchOffset, @@ -135,11 +149,28 @@ public ArrowMessage(ArrowRecordBatch batch) { this.message = Message.getRootAsMessage(serializedMessage); this.bufs = ImmutableList.copyOf(batch.getBuffers()); this.descriptor = null; + this.appMetadata = appMetadata; } - private ArrowMessage(FlightDescriptor descriptor, Message message, ArrowBuf buf) { + public ArrowMessage(ArrowDictionaryBatch batch) { + FlatBufferBuilder builder = new FlatBufferBuilder(); + int batchOffset = batch.writeTo(builder); + ByteBuffer serializedMessage = MessageSerializer + .serializeMessage(builder, MessageHeader.DictionaryBatch, batchOffset, + batch.computeBodyLength()); + serializedMessage = serializedMessage.slice(); + this.message = Message.getRootAsMessage(serializedMessage); + // asInputStream will free the buffers implicitly, so increment the reference count + batch.getDictionary().getBuffers().forEach(buf -> buf.getReferenceManager().retain()); + this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers()); + this.descriptor = null; + this.appMetadata = null; + } + + private ArrowMessage(FlightDescriptor descriptor, Message message, ArrowBuf appMetadata, ArrowBuf buf) { this.message = message; this.descriptor = descriptor; + this.appMetadata = appMetadata; this.bufs = buf == null ? 
ImmutableList.of() : ImmutableList.of(buf); } @@ -169,10 +200,18 @@ public ArrowRecordBatch asRecordBatch() throws IOException { RecordBatch recordBatch = new RecordBatch(); message.header(recordBatch); ArrowBuf underlying = bufs.get(0); + underlying.getReferenceManager().retain(); ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(recordBatch, underlying); return batch; } + public ArrowDictionaryBatch asDictionaryBatch() throws IOException { + Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf."); + Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH); + ArrowBuf underlying = bufs.get(0); + return MessageSerializer.deserializeDictionaryBatch(message, underlying); + } + public Iterable getBufs() { return Iterables.unmodifiableIterable(bufs); } @@ -183,6 +222,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s FlightDescriptor descriptor = null; Message header = null; ArrowBuf body = null; + ArrowBuf appMetadata = null; while (stream.available() > 0) { int tag = readRawVarint32(stream); switch (tag) { @@ -201,6 +241,12 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s header = Message.getRootAsMessage(ByteBuffer.wrap(bytes)); break; } + case APP_METADATA_TAG: { + int size = readRawVarint32(stream); + appMetadata = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, FAST_PATH); + break; + } case BODY_TAG: if (body != null) { // only read last body. @@ -209,15 +255,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s } int size = readRawVarint32(stream); body = allocator.buffer(size); - ReadableBuffer readableBuffer = FAST_PATH ? 
GetReadableBuffer.getReadableBuffer(stream) : null; - if (readableBuffer != null) { - readableBuffer.readBytes(body.nioBuffer(0, size)); - } else { - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - body.writeBytes(heapBytes); - } - body.writerIndex(size); + GetReadableBuffer.readIntoBuffer(stream, body, size, FAST_PATH); break; default: @@ -225,7 +263,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s } } - return new ArrowMessage(descriptor, header, body); + return new ArrowMessage(descriptor, header, appMetadata, body); } catch (Exception ioe) { throw new RuntimeException(ioe); } @@ -246,7 +284,6 @@ private InputStream asInputStream(BufferAllocator allocator) { final ByteString bytes = ByteString.copyFrom(message.getByteBuffer(), message.getByteBuffer().remaining()); - if (getMessageType() == HeaderType.SCHEMA) { final FlightData.Builder builder = FlightData.newBuilder() @@ -260,15 +297,23 @@ private InputStream asInputStream(BufferAllocator allocator) { return NO_BODY_MARSHALLER.stream(builder.build()); } - Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH); + Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH || + getMessageType() == HeaderType.DICTIONARY_BATCH); Preconditions.checkArgument(!bufs.isEmpty()); Preconditions.checkArgument(descriptor == null, "Descriptor should only be included in the schema message."); ByteArrayOutputStream baos = new ByteArrayOutputStream(); CodedOutputStream cos = CodedOutputStream.newInstance(baos); cos.writeBytes(FlightData.DATA_HEADER_FIELD_NUMBER, bytes); - cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + if (appMetadata != null && appMetadata.capacity() > 0) { + // Must call slice() as CodedOutputStream#writeByteBuffer writes -capacity- bytes, not -limit- bytes + cos.writeByteBuffer(FlightData.APP_METADATA_FIELD_NUMBER, appMetadata.asNettyBuffer().nioBuffer().slice()); + // 
This is weird, but implicitly, writing an ArrowMessage frees any references it has + appMetadata.getReferenceManager().release(); + } + + cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); int size = 0; List allBufs = new ArrayList<>(); for (ArrowBuf b : bufs) { @@ -290,6 +335,7 @@ private InputStream asInputStream(BufferAllocator allocator) { initialBuf.writeBytes(baos.toByteArray()); final CompositeByteBuf bb = new CompositeByteBuf(allocator.getAsByteBufAllocator(), true, bufs.size() + 1, ImmutableList.builder().add(initialBuf.asNettyBuffer()).addAll(allBufs).build()); + // Implicitly, transfer ownership of our buffers to the input stream (which will decrement the refcount when done) final ByteBufInputStream is = new DrainableByteBufInputStream(bb); return is; } catch (Exception ex) { @@ -319,7 +365,7 @@ public int drainTo(OutputStream target) throws IOException { } @Override - public void close() throws IOException { + public void close() { buf.release(); } @@ -354,5 +400,8 @@ public ArrowMessage parse(InputStream stream) { @Override public void close() throws Exception { AutoCloseables.close(bufs); + if (appMetadata != null) { + appMetadata.close(); + } } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java new file mode 100644 index 00000000000..c8214e31953 --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +/** + * A handler for server-sent application metadata messages during a Flight DoPut operation. + * + *

To handle messages, create an instance of this class overriding {@link #onNext(PutResult)}. The other methods + * should not be overridden. + */ +public class AsyncPutListener implements FlightClient.PutListener { + + private CompletableFuture completed; + + public AsyncPutListener() { + completed = new CompletableFuture<>(); + } + + /** + * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have + * happened during the upload. + */ + @Override + public final void getResult() { + try { + completed.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onNext(PutResult val) { + } + + @Override + public final void onError(Throwable t) { + completed.completeExceptionally(t); + } + + @Override + public final void onCompleted() { + completed.complete(null); + } +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java new file mode 100644 index 00000000000..6409b6ac63f --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; + +/** + * Utilities to work with dictionaries in Flight. + */ +final class DictionaryUtils { + + private DictionaryUtils() { + throw new UnsupportedOperationException("Do not instantiate this class."); + } + + /** + * Generate all the necessary Flight messages to send a schema and associated dictionaries. + */ + static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor, + final DictionaryProvider provider, final Consumer messageCallback) { + final List fields = new ArrayList<>(originalSchema.getFields().size()); + final Set dictionaryIds = new HashSet<>(); + for (final Field field : originalSchema.getFields()) { + fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIds)); + } + final Schema schema = new Schema(fields, originalSchema.getCustomMetadata()); + // Send the schema message + messageCallback.accept(new ArrowMessage(descriptor == null ? 
null : descriptor.toProtocol(), schema)); + // Create and write dictionary batches + for (Long id : dictionaryIds) { + final Dictionary dictionary = provider.lookup(id); + final FieldVector vector = dictionary.getVector(); + final int count = vector.getValueCount(); + // Do NOT close this root, as it does not actually own the vector. + final VectorSchemaRoot dictRoot = new VectorSchemaRoot( + Collections.singletonList(vector.getField()), + Collections.singletonList(vector), + count); + final VectorUnloader unloader = new VectorUnloader(dictRoot); + try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch( + id, unloader.getRecordBatch())) { + messageCallback.accept(new ArrowMessage(dictionaryBatch)); + } + } + return schema; + } +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java index d352b2bbc6d..13a28f9fd9f 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java @@ -28,12 +28,11 @@ import io.grpc.BindableService; import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.ServerCalls; -import io.grpc.stub.ServerCalls.ClientStreamingMethod; -import io.grpc.stub.ServerCalls.ServerStreamingMethod; import io.grpc.stub.StreamObserver; /** @@ -66,7 +65,7 @@ public static MethodDescriptor getDoGetDescriptor(B public static MethodDescriptor getDoPutDescriptor(BufferAllocator allocator) { return MethodDescriptor.newBuilder() - .setType(io.grpc.MethodDescriptor.MethodType.CLIENT_STREAMING) + .setType(MethodType.BIDI_STREAMING) .setFullMethodName(DO_PUT) .setSampledToLocalTracing(false) .setRequestMarshaller(ArrowMessage.createMarshaller(allocator)) @@ -84,7 +83,7 @@ public 
ServerServiceDefinition bindService() { ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(FlightConstants.SERVICE); serviceBuilder.addMethod(doGetDescriptor, ServerCalls.asyncServerStreamingCall(new DoGetMethod(delegate))); - serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncClientStreamingCall(new DoPutMethod(delegate))); + serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncBidiStreamingCall(new DoPutMethod(delegate))); // copy over not-overridden methods. for (ServerMethodDefinition definition : baseDefinition.getMethods()) { @@ -98,7 +97,7 @@ public ServerServiceDefinition bindService() { return serviceBuilder.build(); } - private class DoGetMethod implements ServerStreamingMethod { + private class DoGetMethod implements ServerCalls.ServerStreamingMethod { private final FlightService delegate; @@ -112,7 +111,7 @@ public void invoke(Flight.Ticket request, StreamObserver responseO } } - private class DoPutMethod implements ClientStreamingMethod { + private class DoPutMethod implements ServerCalls.BidiStreamingMethod { private final FlightService delegate; public DoPutMethod(FlightService delegate) { diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 37e4514db7d..9ac3686fea3 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -17,9 +17,6 @@ package org.apache.arrow.flight; -import static io.grpc.stub.ClientCalls.asyncClientStreamingCall; -import static io.grpc.stub.ClientCalls.asyncServerStreamingCall; - import java.io.InputStream; import java.net.URISyntaxException; import java.util.Iterator; @@ -28,27 +25,26 @@ import javax.net.ssl.SSLException; +import org.apache.arrow.flight.FlightProducer.StreamListener; import org.apache.arrow.flight.auth.BasicClientAuthHandler; import 
org.apache.arrow.flight.auth.ClientAuthHandler; import org.apache.arrow.flight.auth.ClientAuthInterceptor; import org.apache.arrow.flight.auth.ClientAuthWrapper; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.Flight.Empty; -import org.apache.arrow.flight.impl.Flight.PutResult; import org.apache.arrow.flight.impl.FlightServiceGrpc; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; import io.grpc.ClientCall; import io.grpc.ManagedChannel; @@ -56,9 +52,11 @@ import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientCalls; import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; +import io.netty.buffer.ArrowBuf; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; import io.netty.handler.ssl.SslContextBuilder; @@ -78,6 +76,9 @@ public class FlightClient implements AutoCloseable { private final MethodDescriptor doGetDescriptor; private final MethodDescriptor doPutDescriptor; + /** + * Create a Flight client from an allocator and a gRPC channel. 
+ */ private FlightClient(BufferAllocator incomingAllocator, ManagedChannel channel) { this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE); this.channel = channel; @@ -156,27 +157,42 @@ public void authenticate(ClientAuthHandler handler, CallOption... options) { /** * Create or append a descriptor with another stream. - * @param descriptor FlightDescriptor - * @param root VectorSchemaRoot + * + * @param descriptor FlightDescriptor the descriptor for the data + * @param root VectorSchemaRoot the root containing data + * @param metadataListener A handler for metadata messages from the server. This will be passed buffers that will be + * freed after {@link StreamListener#onNext(Object)} is called! * @param options RPC-layer hints for this call. - * @return ClientStreamListener + * @return ClientStreamListener an interface to control uploading data */ - public ClientStreamListener startPut( - FlightDescriptor descriptor, VectorSchemaRoot root, CallOption... options) { + public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, + PutListener metadataListener, CallOption... options) { + return startPut(descriptor, root, new MapDictionaryProvider(), metadataListener, options); + } + + /** + * Create or append a descriptor with another stream. + * @param descriptor FlightDescriptor the descriptor for the data + * @param root VectorSchemaRoot the root containing data + * @param metadataListener A handler for metadata messages from the server. + * @param options RPC-layer hints for this call. + * @return ClientStreamListener an interface to control uploading data + */ + public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider, + PutListener metadataListener, CallOption... 
options) { Preconditions.checkNotNull(descriptor); Preconditions.checkNotNull(root); - SetStreamObserver resultObserver = new SetStreamObserver<>(); + SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener); final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); ClientCallStreamObserver observer = (ClientCallStreamObserver) - asyncClientStreamingCall( - authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver); + ClientCalls.asyncBidiStreamingCall( + authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver); // send the schema to start. - ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema()); - observer.onNext(message); + DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext); return new PutObserver(new VectorUnloader( root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */), - observer, resultObserver.getFuture()); + observer, metadataListener); } /** @@ -202,7 +218,7 @@ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { public FlightStream getStream(Ticket ticket, CallOption... 
options) { final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); ClientCall call = - authInterceptor.interceptCall(doGetDescriptor, callOptions, channel); + authInterceptor.interceptCall(doGetDescriptor, callOptions, channel); FlightStream stream = new FlightStream( allocator, PENDING_REQUESTS, @@ -235,54 +251,64 @@ public void onCompleted() { }; - asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver); + ClientCalls.asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver); return stream; } - private static class SetStreamObserver implements StreamObserver { - private final SettableFuture result = SettableFuture.create(); - private volatile T resultLocal; + private static class SetStreamObserver implements StreamObserver { + private final BufferAllocator allocator; + private final StreamListener listener; + + SetStreamObserver(BufferAllocator allocator, StreamListener listener) { + super(); + this.allocator = allocator; + this.listener = listener == null ? 
NoOpStreamListener.getInstance() : listener; + } @Override - public void onNext(T value) { - resultLocal = value; + public void onNext(Flight.PutResult value) { + try (final PutResult message = PutResult.fromProtocol(allocator, value)) { + listener.onNext(message); + } } @Override public void onError(Throwable t) { - result.setException(t); + listener.onError(t); } @Override public void onCompleted() { - result.set(Preconditions.checkNotNull(resultLocal)); - } - - public ListenableFuture getFuture() { - return result; + listener.onCompleted(); } } private static class PutObserver implements ClientStreamListener { + private final ClientCallStreamObserver observer; private final VectorUnloader unloader; - private final ListenableFuture futureResult; + private final PutListener listener; public PutObserver(VectorUnloader unloader, ClientCallStreamObserver observer, - ListenableFuture futureResult) { + PutListener listener) { this.observer = observer; this.unloader = unloader; - this.futureResult = futureResult; + this.listener = listener; } @Override public void putNext() { + putNext(null); + } + + @Override + public void putNext(ArrowBuf appMetadata) { ArrowRecordBatch batch = unloader.getRecordBatch(); - // Check the futureResult in case server sent an exception - while (!observer.isReady() && !futureResult.isDone()) { + while (!observer.isReady()) { /* busy wait */ } - observer.onNext(new ArrowMessage(batch)); + // Takes ownership of appMetadata + observer.onNext(new ArrowMessage(batch, appMetadata)); } @Override @@ -296,12 +322,8 @@ public void completed() { } @Override - public PutResult getResult() { - try { - return futureResult.get(); - } catch (Exception ex) { - throw Throwables.propagate(ex); - } + public void getResult() { + listener.getResult(); } } @@ -310,17 +332,59 @@ public PutResult getResult() { */ public interface ClientStreamListener { + /** + * Send the current data in the corresponding {@link VectorSchemaRoot} to the server. 
+ */ void putNext(); + /** + * Send the current data in the corresponding {@link VectorSchemaRoot} to the server, along with + * application-specific metadata. This takes ownership of the buffer. + */ + void putNext(ArrowBuf appMetadata); + + /** + * Indicate an error to the server. Terminates the stream; do not call {@link #completed()}. + */ void error(Throwable ex); + /** Indicate the stream is finished on the client side. */ void completed(); - PutResult getResult(); - + /** + * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have + * happened during the upload. + */ + void getResult(); } + /** + * A handler for server-sent application metadata messages during a Flight DoPut operation. + * + *

Generally, instead of implementing this yourself, you should use {@link AsyncPutListener} or {@link + * SyncPutListener}. + */ + public interface PutListener extends StreamListener { + /** + * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have + * happened during the upload. + */ + void getResult(); + + /** + * Called when a message from the server is received. + * + * @param val The application metadata. This buffer will be reclaimed once onNext returns; you must retain a + * reference to use it outside this method. + */ + @Override + void onNext(PutResult val); + } + + /** + * Shut down this client. + */ public void close() throws InterruptedException { channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); allocator.close(); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java index 9c3cc2b59c7..fdb5e9f586c 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -17,49 +17,118 @@ package org.apache.arrow.flight; -import java.util.concurrent.Callable; - -import org.apache.arrow.flight.impl.Flight.PutResult; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; + +import io.netty.buffer.ArrowBuf; /** * API to Implement an Arrow Flight producer. */ public interface FlightProducer { - void getStream(CallContext context, Ticket ticket, - ServerStreamListener listener); + /** + * Return data for a stream. + * + * @param context Per-call context. + * @param ticket The application-defined ticket identifying this stream. + * @param listener An interface for sending data back to the client. + */ + void getStream(CallContext context, Ticket ticket, ServerStreamListener listener); + /** + * List available data streams on this service. 
+ * + * @param context Per-call context. + * @param criteria Application-defined criteria for filtering streams. + * @param listener An interface for sending data back to the client. + */ void listFlights(CallContext context, Criteria criteria, StreamListener listener); - FlightInfo getFlightInfo(CallContext context, - FlightDescriptor descriptor); + /** + * Get information about a particular data stream. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor); - Callable acceptPut(CallContext context, - FlightStream flightStream); + /** + * Accept uploaded data for a particular stream. + * + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + */ + Runnable acceptPut(CallContext context, + FlightStream flightStream, StreamListener ackStream); + /** + * Generic handler for application-defined RPCs. + * + * @param context Per-call context. + * @param action Client-supplied parameters. + * @param listener A stream of responses. + */ void doAction(CallContext context, Action action, StreamListener listener); - void listActions(CallContext context, - StreamListener listener); + /** + * List available application-defined RPCs. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void listActions(CallContext context, StreamListener listener); /** - * Listener for creating a stream on the server side. + * An interface for sending Arrow data back to a client. */ interface ServerStreamListener { + /** + * Check whether the call has been cancelled. If so, stop sending data. + */ boolean isCancelled(); + /** + * A hint indicating whether the client is ready to receive data without excessive buffering. 
+ */ boolean isReady(); + /** + * Start sending data, using the schema of the given {@link VectorSchemaRoot}. + * + *

This method must be called before all others. + */ void start(VectorSchemaRoot root); + /** + * Start sending data, using the schema of the given {@link VectorSchemaRoot}. + * + *

This method must be called before all others. + */ + void start(VectorSchemaRoot root, DictionaryProvider dictionaries); + + /** + * Send the current contents of the associated {@link VectorSchemaRoot}. + */ void putNext(); + /** + * Send the current contents of the associated {@link VectorSchemaRoot} alongside application-defined metadata. + * @param metadata The metadata to send. Ownership of the buffer is transferred to the Flight implementation. + */ + void putNext(ArrowBuf metadata); + + /** + * Indicate an error to the client. Terminates the stream; do not call {@link #completed()} afterwards. + */ void error(Throwable ex); + /** + * Indicate that transmission is finished. + */ void completed(); } @@ -71,10 +140,21 @@ interface ServerStreamListener { */ interface StreamListener { + /** + * Send the next value to the client. + */ void onNext(T val); + /** + * Indicate an error to the client. + * + *

Terminates the stream; do not call {@link #onCompleted()}. + */ void onError(Throwable t); + /** + * Indicate that the transmission is finished. + */ void onCompleted(); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java index cd59a75cbbb..3f02dd588e1 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -21,6 +21,8 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; import java.util.HashMap; import java.util.Map; import java.util.concurrent.Executor; @@ -47,13 +49,15 @@ public class FlightServer implements AutoCloseable { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FlightServer.class); + private final Location location; private final Server server; /** The maximum size of an individual gRPC message. This effectively disables the limit. */ static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE; /** Create a new instance from a gRPC server. For internal use only. */ - private FlightServer(Server server) { + private FlightServer(Location location, Server server) { + this.location = location; this.server = server; } @@ -63,10 +67,27 @@ public FlightServer start() throws IOException { return this; } + /** Get the port the server is running on (if applicable). */ public int getPort() { return server.getPort(); } + /** Get the location for this server. */ + public Location getLocation() { + if (location.getUri().getPort() == 0) { + // If the server was bound to port 0, replace the port in the location with the real port. 
+ final URI uri = location.getUri(); + try { + return new Location(new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), getPort(), + uri.getPath(), uri.getQuery(), uri.getFragment())); + } catch (URISyntaxException e) { + // We don't expect this to happen + throw new RuntimeException(e); + } + } + return location; + } + /** Block until the server shuts down. */ public void awaitTermination() throws InterruptedException { server.awaitTermination(); @@ -211,7 +232,7 @@ public FlightServer build() { return null; }); - return new FlightServer(builder.build()); + return new FlightServer(location, builder.build()); } /** diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java index b5c22efb224..ee45cef24d3 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java @@ -30,18 +30,21 @@ import org.apache.arrow.flight.impl.Flight.Empty; import org.apache.arrow.flight.impl.Flight.HandshakeRequest; import org.apache.arrow.flight.impl.Flight.HandshakeResponse; -import org.apache.arrow.flight.impl.Flight.PutResult; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Preconditions; +import io.grpc.Status; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; +import io.netty.buffer.ArrowBuf; /** * GRPC service implementation for a flight server. 
@@ -138,15 +141,25 @@ public boolean isCancelled() { @Override public void start(VectorSchemaRoot root) { - responseObserver.onNext(new ArrowMessage(null, root.getSchema())); - // [ARROW-4213] We must align buffers to be compatible with other languages. + start(root, new MapDictionaryProvider()); + } + + @Override + public void start(VectorSchemaRoot root, DictionaryProvider provider) { unloader = new VectorUnloader(root, true, true); + + DictionaryUtils.generateSchemaMessages(root.getSchema(), null, provider, responseObserver::onNext); } @Override public void putNext() { + putNext(null); + } + + @Override + public void putNext(ArrowBuf metadata) { Preconditions.checkNotNull(unloader); - responseObserver.onNext(new ArrowMessage(unloader.getRecordBatch())); + responseObserver.onNext(new ArrowMessage(unloader.getRecordBatch(), metadata)); } @Override @@ -161,18 +174,33 @@ public void completed() { } - public StreamObserver doPutCustom(final StreamObserver responseObserverSimple) { - ServerCallStreamObserver responseObserver = (ServerCallStreamObserver) responseObserverSimple; + public StreamObserver doPutCustom(final StreamObserver responseObserverSimple) { + ServerCallStreamObserver responseObserver = + (ServerCallStreamObserver) responseObserverSimple; responseObserver.disableAutoInboundFlowControl(); responseObserver.request(1); - FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, null, (count) -> responseObserver.request(count)); + // Set a default metadata listener that does nothing. Service implementations should call + // FlightStream#setMetadataListener before returning a Runnable if they want to receive metadata. 
+ FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, (String message, Throwable cause) -> { + responseObserver.onError(Status.CANCELLED.withCause(cause).withDescription(message).asException()); + }, responseObserver::request); executors.submit(() -> { try { - responseObserver.onNext(producer.acceptPut(makeContext(responseObserver), fs).call()); + producer.acceptPut(makeContext(responseObserver), fs, + StreamPipe.wrap(responseObserver, PutResult::toProtocol)).run(); responseObserver.onCompleted(); } catch (Exception ex) { responseObserver.onError(ex); + // The client may have terminated, so the exception here is effectively swallowed. + // Log the error as well so -something- makes it to the developer. + logger.error("Exception handling DoPut", ex); + } + try { + fs.close(); + } catch (Exception e) { + logger.error("Exception closing Flight stream", e); + throw new RuntimeException(e); } }); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java index 79685c49f79..010ff330a2c 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -17,17 +17,28 @@ package org.apache.arrow.flight; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.stream.Collectors; +import org.apache.arrow.flight.ArrowMessage.HeaderType; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import 
org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -35,11 +46,12 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.stub.StreamObserver; +import io.netty.buffer.ArrowBuf; /** * An adaptor between protobuf streams and flight data streams. */ -public class FlightStream { +public class FlightStream implements AutoCloseable { private final Object DONE = new Object(); @@ -56,10 +68,12 @@ public class FlightStream { private volatile int pending = 1; private boolean completed = false; private volatile VectorSchemaRoot fulfilledRoot; + private DictionaryProvider.MapDictionaryProvider dictionaries; private volatile VectorLoader loader; private volatile Throwable ex; private volatile FlightDescriptor descriptor; private volatile Schema schema; + private volatile ArrowBuf applicationMetadata = null; /** * Constructs a new instance. @@ -74,12 +88,17 @@ public FlightStream(BufferAllocator allocator, int pendingTarget, Cancellable ca this.pendingTarget = pendingTarget; this.cancellable = cancellable; this.requestor = requestor; + this.dictionaries = new DictionaryProvider.MapDictionaryProvider(); } public Schema getSchema() { return schema; } + public DictionaryProvider getDictionaryProvider() { + return dictionaries; + } + public FlightDescriptor getDescriptor() { return descriptor; } @@ -98,7 +117,10 @@ public void close() throws Exception { .map(t -> ((AutoCloseable) t)) .collect(Collectors.toList()); - AutoCloseables.close(Iterables.concat(closeables, ImmutableList.of(root.get()))); + // Must check for null since ImmutableList doesn't accept nulls + AutoCloseables.close(Iterables.concat(closeables, + applicationMetadata != null ? 
ImmutableList.of(root.get(), applicationMetadata) + : ImmutableList.of(root.get()))); } /** @@ -131,15 +153,41 @@ public boolean next() { throw new Exception(ex); } } else { - ArrowMessage msg = ((ArrowMessage) data); - try (ArrowRecordBatch arb = msg.asRecordBatch()) { - loader.load(arb); + try (ArrowMessage msg = ((ArrowMessage) data)) { + if (msg.getMessageType() == HeaderType.RECORD_BATCH) { + try (ArrowRecordBatch arb = msg.asRecordBatch()) { + loader.load(arb); + } + if (this.applicationMetadata != null) { + this.applicationMetadata.close(); + } + this.applicationMetadata = msg.getApplicationMetadata(); + if (this.applicationMetadata != null) { + this.applicationMetadata.getReferenceManager().retain(); + } + } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) { + try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) { + final long id = arb.getDictionaryId(); + final Dictionary dictionary = dictionaries.lookup(id); + if (dictionary == null) { + throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id); + } + + final FieldVector vector = dictionary.getVector(); + final VectorSchemaRoot dictionaryRoot = new VectorSchemaRoot(Collections.singletonList(vector.getField()), + Collections.singletonList(vector), 0); + final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot); + dictionaryLoader.load(arb.getDictionary()); + } + return next(); + } else { + throw new UnsupportedOperationException("Message type is unsupported: " + msg.getMessageType()); + } + return true; } - return true; } - } catch (Exception e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -152,6 +200,17 @@ public VectorSchemaRoot getRoot() { } } + /** + * Get the most recent metadata sent from the server. This may be cleared by calls to {@link #next()} if the server + * sends a message without metadata. 
This does NOT take ownership of the buffer - call retain() to create a reference + * if you need the buffer after a call to {@link #next()}. + * + * @return the application metadata. May be null. + */ + public ArrowBuf getLatestMetadata() { + return applicationMetadata; + } + private synchronized void requestOutstanding() { if (pending < pendingTarget) { requestor.request(pendingTarget - pending); @@ -169,23 +228,36 @@ public Observer() { public void onNext(ArrowMessage msg) { requestOutstanding(); switch (msg.getMessageType()) { - case SCHEMA: + case SCHEMA: { schema = msg.asSchema(); + final List fields = new ArrayList<>(); + final Map dictionaryMap = new HashMap<>(); + for (final Field originalField : schema.getFields()) { + final Field updatedField = DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap); + fields.add(updatedField); + } + for (final Map.Entry entry : dictionaryMap.entrySet()) { + dictionaries.put(entry.getValue()); + } + schema = new Schema(fields, schema.getCustomMetadata()); fulfilledRoot = VectorSchemaRoot.create(schema, allocator); loader = new VectorLoader(fulfilledRoot); descriptor = msg.getDescriptor() != null ? 
new FlightDescriptor(msg.getDescriptor()) : null; root.set(fulfilledRoot); break; + } case RECORD_BATCH: queue.add(msg); break; - case NONE: case DICTIONARY_BATCH: + queue.add(msg); + break; + case NONE: case TENSOR: default: queue.add(DONE_EX); - ex = new UnsupportedOperationException("Unable to handle message of type: " + msg); + ex = new UnsupportedOperationException("Unable to handle message of type: " + msg.getMessageType()); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java index 0e6e373e654..eca32e1c679 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java @@ -17,10 +17,6 @@ package org.apache.arrow.flight; -import java.util.concurrent.Callable; - -import org.apache.arrow.flight.impl.Flight.PutResult; - /** * A {@link FlightProducer} that throws on all operations. 
*/ @@ -45,8 +41,8 @@ public FlightInfo getFlightInfo(CallContext context, } @Override - public Callable acceptPut(CallContext context, - FlightStream flightStream) { + public Runnable acceptPut(CallContext context, + FlightStream flightStream, StreamListener ackStream) { throw new UnsupportedOperationException("NYI"); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java b/java/flight/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java similarity index 54% rename from java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java rename to java/flight/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java index 03a1e92af12..e06af1a1026 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java @@ -17,26 +17,33 @@ package org.apache.arrow.flight; +import org.apache.arrow.flight.FlightProducer.StreamListener; + /** - * Unused?. + * A {@link StreamListener} that does nothing for all callbacks. + * @param The type of the callback object. */ -class GenericOperation { - - private final String type; - private final byte[] body; +public class NoOpStreamListener implements StreamListener { + private static NoOpStreamListener INSTANCE = new NoOpStreamListener(); - public GenericOperation(String type, byte[] body) { - super(); - this.type = type; - this.body = body == null ? new byte[0] : body; + /** Ignores the value received. */ + @Override + public void onNext(T val) { } - public String getType() { - return type; + /** Ignores the error received. */ + @Override + public void onError(Throwable t) { } - public byte[] getBody() { - return body; + /** Ignores the stream completion event. 
*/ + @Override + public void onCompleted() { } + @SuppressWarnings("unchecked") + public static StreamListener getInstance() { + // Safe because we never use T + return (StreamListener) INSTANCE; + } } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java b/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java new file mode 100644 index 00000000000..11848690d07 --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ReferenceManager; + +import com.google.protobuf.ByteString; + +import io.netty.buffer.ArrowBuf; + +/** + * A message from the server during a DoPut operation. + * + *

This object owns an {@link ArrowBuf} and should be closed when you are done with it. + */ +public class PutResult implements AutoCloseable { + + private ArrowBuf applicationMetadata; + + private PutResult(ArrowBuf metadata) { + applicationMetadata = metadata; + } + + /** + * Create a PutResult with application-specific metadata. + * + *

This method assumes ownership of the {@link ArrowBuf}. + */ + public static PutResult metadata(ArrowBuf metadata) { + if (metadata == null) { + return empty(); + } + return new PutResult(metadata); + } + + /** Create an empty PutResult. */ + public static PutResult empty() { + return new PutResult(null); + } + + /** + * Get the metadata in this message. May be null. + * + *

Ownership of the {@link ArrowBuf} is retained by this object. Call {@link ReferenceManager#retain()} to preserve + * a reference. + */ + public ArrowBuf getApplicationMetadata() { + return applicationMetadata; + } + + Flight.PutResult toProtocol() { + if (applicationMetadata == null) { + return Flight.PutResult.getDefaultInstance(); + } + return Flight.PutResult.newBuilder().setAppMetadata(ByteString.copyFrom(applicationMetadata.nioBuffer())).build(); + } + + /** + * Construct a PutResult from a Protobuf message. + * + * @param allocator The allocator to use for allocating application metadata memory. The result object owns the + * allocated buffer, if any. + * @param message The gRPC/Protobuf message. + */ + static PutResult fromProtocol(BufferAllocator allocator, Flight.PutResult message) { + final ArrowBuf buf = allocator.buffer(message.getAppMetadata().size()); + message.getAppMetadata().asReadOnlyByteBufferList().forEach(bb -> { + buf.setBytes(buf.writerIndex(), bb); + buf.writerIndex(buf.writerIndex() + bb.limit()); + }); + return new PutResult(buf); + } + + @Override + public void close() { + if (applicationMetadata != null) { + applicationMetadata.close(); + } + } +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java new file mode 100644 index 00000000000..f1246a1d079 --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import io.netty.buffer.ArrowBuf; + +/** + * A listener for server-sent application metadata messages during a Flight DoPut. This class wraps the messages in a + * synchronous interface. + */ +public final class SyncPutListener implements FlightClient.PutListener, AutoCloseable { + + private final LinkedBlockingQueue queue; + private final CompletableFuture completed; + private static final Object DONE = new Object(); + private static final Object DONE_WITH_EXCEPTION = new Object(); + + public SyncPutListener() { + queue = new LinkedBlockingQueue<>(); + completed = new CompletableFuture<>(); + } + + private PutResult unwrap(Object queueItem) throws InterruptedException, ExecutionException { + if (queueItem == DONE) { + queue.put(queueItem); + return null; + } else if (queueItem == DONE_WITH_EXCEPTION) { + queue.put(queueItem); + completed.get(); + } + return (PutResult) queueItem; + } + + /** + * Get the next message from the server, blocking until it is available. + * + * @return The next message, or null if the server is done sending messages. The caller assumes ownership of the + * metadata and must remember to close it. + * @throws InterruptedException if interrupted while waiting. + * @throws ExecutionException if the server sent an error, or if there was an internal error. 
+ */ + public PutResult read() throws InterruptedException, ExecutionException { + return unwrap(queue.take()); + } + + /** + * Get the next message from the server, blocking for the specified amount of time until it is available. + * + * @return The next message, or null if the server is done sending messages or no message arrived before the timeout. + * The caller assumes ownership of the metadata and must remember to close it. + * @throws InterruptedException if interrupted while waiting. + * @throws ExecutionException if the server sent an error, or if there was an internal error. + */ + public PutResult poll(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException { + return unwrap(queue.poll(timeout, unit)); + } + + @Override + public void getResult() { + try { + completed.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onNext(PutResult val) { + final ArrowBuf metadata = val.getApplicationMetadata(); + metadata.getReferenceManager().retain(); + queue.add(PutResult.metadata(metadata)); + } + + @Override + public void onError(Throwable t) { + completed.completeExceptionally(t); + queue.add(DONE_WITH_EXCEPTION); + } + + @Override + public void onCompleted() { + completed.complete(null); + queue.add(DONE); + } + + @Override + public void close() { + queue.forEach(o -> { + if (o instanceof PutResult) { + ((PutResult) o).close(); + } + }); + } +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java b/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java index 91ed04e7ffa..cf3eb154ed7 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java @@ -28,6 +28,7 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; 
+import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.Schema; import com.google.common.base.Preconditions; @@ -43,19 +44,22 @@ public class FlightHolder implements AutoCloseable { private final FlightDescriptor descriptor; private final Schema schema; private final List streams = new CopyOnWriteArrayList<>(); + private final DictionaryProvider dictionaryProvider; /** * Creates a new instance. - * - * @param allocator The allocator to use for allocating buffers to store data. + * @param allocator The allocator to use for allocating buffers to store data. * @param descriptor The descriptor for the streams. * @param schema The schema for the stream. + * @param dictionaryProvider The dictionary provider for the stream. */ - public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema) { + public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema, + DictionaryProvider dictionaryProvider) { Preconditions.checkArgument(!descriptor.isCommand()); this.allocator = allocator.newChildAllocator(descriptor.toString(), 0, Long.MAX_VALUE); this.descriptor = descriptor; this.schema = schema; + this.dictionaryProvider = dictionaryProvider; } /** @@ -72,8 +76,8 @@ public Stream getStream(ExampleTicket ticket) { * Adds a new streams which clients can populate via the returned object. 
*/ public Stream.StreamCreator addStream(Schema schema) { - Preconditions.checkArgument(schema.equals(schema), "Stream schema inconsistent with existing schema."); - return new Stream.StreamCreator(schema, allocator, t -> { + Preconditions.checkArgument(this.schema.equals(schema), "Stream schema inconsistent with existing schema."); + return new Stream.StreamCreator(schema, dictionaryProvider, allocator, t -> { synchronized (streams) { streams.add(t); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java index 452faa17cdd..59324b30397 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java @@ -17,7 +17,6 @@ package org.apache.arrow.flight.example; -import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -29,15 +28,14 @@ import org.apache.arrow.flight.FlightProducer; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; import org.apache.arrow.flight.Result; import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.example.Stream.StreamCreator; -import org.apache.arrow.flight.impl.Flight.PutResult; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.types.pojo.Schema; /** * A FlightProducer that hosts an in memory store of Arrow buffers. @@ -80,17 +78,6 @@ public Stream getStream(Ticket t) { return h.getStream(example); } - /** - * Create a new {@link Stream} with the given schema and descriptor. 
- */ - public StreamCreator putStream(final FlightDescriptor descriptor, final Schema schema) { - final FlightHolder h = holders.computeIfAbsent( - descriptor, - t -> new FlightHolder(allocator, t, schema)); - - return h.addStream(schema); - } - @Override public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { @@ -116,25 +103,25 @@ public FlightInfo getFlightInfo(CallContext context, } @Override - public Callable acceptPut(CallContext context, - final FlightStream flightStream) { + public Runnable acceptPut(CallContext context, + final FlightStream flightStream, final StreamListener ackStream) { return () -> { StreamCreator creator = null; boolean success = false; try (VectorSchemaRoot root = flightStream.getRoot()) { final FlightHolder h = holders.computeIfAbsent( flightStream.getDescriptor(), - t -> new FlightHolder(allocator, t, flightStream.getSchema())); + t -> new FlightHolder(allocator, t, flightStream.getSchema(), flightStream.getDictionaryProvider())); creator = h.addStream(flightStream.getSchema()); VectorUnloader unloader = new VectorUnloader(root); while (flightStream.next()) { + ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata())); creator.add(unloader.getRecordBatch()); } creator.complete(); success = true; - return PutResult.getDefaultInstance(); } finally { if (!success) { creator.drop(); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java b/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java index f36b38cee09..2d42ed25524 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight.example; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -28,18 +29,22 @@ import org.apache.arrow.util.AutoCloseables; import 
org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import io.netty.buffer.ArrowBuf; + /** * A collection of Arrow record batches. */ public class Stream implements AutoCloseable, Iterable { private final String uuid = UUID.randomUUID().toString(); + private final DictionaryProvider dictionaryProvider; private final List batches; private final Schema schema; private final long recordCount; @@ -53,9 +58,11 @@ public class Stream implements AutoCloseable, Iterable { */ public Stream( final Schema schema, + final DictionaryProvider dictionaryProvider, List batches, long recordCount) { this.schema = schema; + this.dictionaryProvider = dictionaryProvider; this.batches = ImmutableList.copyOf(batches); this.recordCount = recordCount; } @@ -82,11 +89,17 @@ public String getUuid() { */ public void sendTo(BufferAllocator allocator, ServerStreamListener listener) { try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - listener.start(root); + listener.start(root, dictionaryProvider); final VectorLoader loader = new VectorLoader(root); + int counter = 0; for (ArrowRecordBatch batch : batches) { + final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8); + final ArrowBuf metadata = allocator.buffer(rawMetadata.length); + metadata.writeBytes(rawMetadata); loader.load(batch); - listener.putNext(); + // Transfers ownership of the buffer - do not free buffer ourselves + listener.putNext(metadata); + counter++; } listener.completed(); } catch (Exception ex) { @@ -118,18 +131,22 @@ public static class StreamCreator { private final List batches = new ArrayList<>(); private final Consumer committer; private long recordCount = 0; + private 
DictionaryProvider dictionaryProvider; /** * Creates a new instance. * * @param schema The schema for batches in the stream. + * @param dictionaryProvider The dictionary provider for the stream. * @param allocator The allocator used to copy data permanently into the stream. * @param committer A callback for when the the stream is ready to be finalized (no more batches). */ - public StreamCreator(Schema schema, BufferAllocator allocator, Consumer committer) { + public StreamCreator(Schema schema, DictionaryProvider dictionaryProvider, + BufferAllocator allocator, Consumer committer) { this.allocator = allocator; this.committer = committer; this.schema = schema; + this.dictionaryProvider = dictionaryProvider; } /** @@ -152,7 +169,7 @@ public void add(ArrowRecordBatch batch) { * Complete building the stream (no more batches can be added). */ public void complete() { - Stream stream = new Stream(schema, batches, recordCount); + Stream stream = new Stream(schema, dictionaryProvider, batches, recordCount); committer.accept(stream); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index ccafde08271..477dfdba4bd 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -19,15 +19,18 @@ import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; +import org.apache.arrow.flight.AsyncPutListener; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import 
org.apache.arrow.flight.PutResult; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorLoader; @@ -41,6 +44,8 @@ import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; +import io.netty.buffer.ArrowBuf; + /** * An Example Flight Server that provides access to the InMemoryStore. */ @@ -89,15 +94,36 @@ private void run(String[] args) throws ParseException, IOException { FlightDescriptor descriptor = FlightDescriptor.path(inputPath); VectorSchemaRoot jsonRoot; try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); - VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator); VectorUnloader unloader = new VectorUnloader(root); VectorLoader jsonLoader = new VectorLoader(jsonRoot); - FlightClient.ClientStreamListener stream = client.startPut(descriptor, root); + FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, + new AsyncPutListener() { + int counter = 0; + + @Override + public void onNext(PutResult val) { + final byte[] metadataRaw = new byte[val.getApplicationMetadata().readableBytes()]; + val.getApplicationMetadata().readBytes(metadataRaw); + final String metadata = new String(metadataRaw, StandardCharsets.UTF_8); + if (!Integer.toString(counter).equals(metadata)) { + throw new RuntimeException( + String.format("Invalid ACK from server. 
Expected '%d' but got '%s'.", counter, metadata)); + } + counter++; + } + }); + int counter = 0; while (reader.read(root)) { - stream.putNext(); + final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8); + final ArrowBuf metadata = allocator.buffer(rawMetadata.length); + metadata.writeBytes(rawMetadata); + // Transfers ownership of the buffer, so do not release it ourselves + stream.putNext(metadata); jsonLoader.load(unloader.getRecordBatch()); root.clear(); + counter++; } stream.completed(); // Need to call this, or exceptions from the server get swallowed diff --git a/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java index 9591cf57f22..b584d961e71 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java @@ -17,12 +17,15 @@ package org.apache.arrow.flight.grpc; +import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; import com.google.common.base.Throwables; +import com.google.common.io.ByteStreams; import io.grpc.internal.ReadableBuffer; +import io.netty.buffer.ArrowBuf; /** * Enable access to ReadableBuffer directly to copy data from an BufferInputStream into a target @@ -72,4 +75,24 @@ public static ReadableBuffer getReadableBuffer(InputStream is) { } } + /** + * Helper method to read a gRPC-provided InputStream into an ArrowBuf. + * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. + * @param buf The buffer to read into. + * @param size The number of bytes to read. + * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link #BUFFER_INPUT_STREAM}). 
+ * @throws IOException if there is an error reading from the stream + */ + public static void readIntoBuffer(final InputStream stream, final ArrowBuf buf, final int size, + final boolean fastPath) throws IOException { + ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null; + if (readableBuffer != null) { + readableBuffer.readBytes(buf.nioBuffer(0, size)); + } else { + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + } + buf.writerIndex(size); + } } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java index 3cb09ef5cd9..a10d490555d 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java @@ -43,14 +43,15 @@ public class FlightTestUtil { * Returns a a FlightServer (actually anything that is startable) * that has been started bound to a random port. 
*/ - public static T getStartedServer(Function newServerFromPort) throws IOException { + public static T getStartedServer(Function newServerFromLocation) throws IOException { IOException lastThrown = null; T server = null; for (int x = 0; x < 3; x++) { final int port = 49152 + RANDOM.nextInt(5000); + final Location location = Location.forGrpcInsecure(LOCALHOST, port); lastThrown = null; try { - server = newServerFromPort.apply(port); + server = newServerFromLocation.apply(location); try { server.getClass().getMethod("start").invoke(server); } catch (NoSuchMethodException | IllegalAccessException e) { diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java new file mode 100644 index 00000000000..ad2c58f3b78 --- /dev/null +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.util.Collections; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import org.apache.arrow.flight.FlightClient.PutListener; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import io.grpc.Status; +import io.netty.buffer.ArrowBuf; + +/** + * Tests for application-specific metadata support in Flight. + */ +public class TestApplicationMetadata { + + /** + * Ensure that a client can read the metadata sent from the server. + */ + @Test + // This test is consistently flaky on CI, unfortunately. + @Ignore + public void retrieveMetadata() { + test((allocator, client) -> { + try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + byte i = 0; + while (stream.next()) { + final IntVector vector = (IntVector) stream.getRoot().getVector("a"); + Assert.assertEquals(1, vector.getValueCount()); + Assert.assertEquals(10, vector.get(0)); + Assert.assertEquals(i, stream.getLatestMetadata().getByte(0)); + i++; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Ensure that a client can send metadata to the server. 
+ */ + @Test + @Ignore + public void uploadMetadataAsync() { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + + final PutListener listener = new AsyncPutListener() { + int counter = 0; + + @Override + public void onNext(PutResult val) { + Assert.assertNotNull(val); + Assert.assertEquals(counter, val.getApplicationMetadata().getByte(0)); + counter++; + } + }; + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + + root.allocateNew(); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + writer.putNext(metadata); + } + writer.completed(); + // Must attempt to retrieve the result to get any server-side errors. + writer.getResult(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Ensure that a client can send metadata to the server. Uses the synchronous API. 
+ */ + @Test + @Ignore + public void uploadMetadataSync() { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final SyncPutListener listener = new SyncPutListener()) { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + + root.allocateNew(); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + writer.putNext(metadata); + try (final PutResult message = listener.poll(5000, TimeUnit.SECONDS)) { + Assert.assertNotNull(message); + Assert.assertEquals(i, message.getApplicationMetadata().getByte(0)); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + writer.completed(); + // Must attempt to retrieve the result to get any server-side errors. + writer.getResult(); + } + }); + } + + /** + * Make sure that a {@link SyncPutListener} properly reclaims memory if ignored. 
+ */ + @Test + @Ignore + public void syncMemoryReclaimed() { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final SyncPutListener listener = new SyncPutListener()) { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + + root.allocateNew(); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + writer.putNext(metadata); + } + writer.completed(); + // Must attempt to retrieve the result to get any server-side errors. + writer.getResult(); + } + }); + } + + private void test(BiConsumer fun) { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(allocator, location, new MetadataFlightProducer(allocator)).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + fun.accept(allocator, client); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * A FlightProducer that always produces a fixed data stream with metadata on the side. 
+ */ + private static class MetadataFlightProducer extends NoOpFlightProducer { + + private final BufferAllocator allocator; + + public MetadataFlightProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + root.allocateNew(); + listener.start(root); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + listener.putNext(metadata); + } + listener.completed(); + } + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream stream, StreamListener ackStream) { + return () -> { + try { + byte current = 0; + while (stream.next()) { + final ArrowBuf metadata = stream.getLatestMetadata(); + if (current != metadata.getByte(0)) { + ackStream.onError(Status.INVALID_ARGUMENT.withDescription(String + .format("Metadata does not match expected value; got %d but expected %d.", metadata.getByte(0), + current)).asRuntimeException()); + return; + } + ackStream.onNext(PutResult.metadata(metadata)); + current++; + } + if (current != 10) { + throw new IllegalArgumentException("Wrong number of messages sent."); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }; + } + } +} diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java index 1b40e7ee426..d0e26e13d7f 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBackPressure.java @@ -46,29 +46,27 @@ public void 
ensureIndependentSteams() throws Exception { try ( final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final PerformanceTestServer server = FlightTestUtil.getStartedServer( - (port) -> (new PerformanceTestServer(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port)))); + (location) -> (new PerformanceTestServer(a, location))); final FlightClient client = FlightClient.builder(a, server.getLocation()).build() ) { - FlightStream fs1 = client.getStream(client.getInfo( + try (FlightStream fs1 = client.getStream(client.getInfo( TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1)) - .getEndpoints().get(0).getTicket()); - consume(fs1, 10); + .getEndpoints().get(0).getTicket())) { + consume(fs1, 10); - // stop consuming fs1 but make sure we can consume a large amount of fs2. - FlightStream fs2 = client.getStream(client.getInfo( - TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1)) - .getEndpoints().get(0).getTicket()); - consume(fs2, 100); + // stop consuming fs1 but make sure we can consume a large amount of fs2. 
+ try (FlightStream fs2 = client.getStream(client.getInfo( + TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1)) + .getEndpoints().get(0).getTicket())) { + consume(fs2, 100); - consume(fs1, 100); - consume(fs2, 100); - - consume(fs1); - consume(fs2); - - fs1.close(); - fs2.close(); + consume(fs1, 100); + consume(fs2, 100); + consume(fs1); + consume(fs2); + } + } } } @@ -92,27 +90,28 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { int batches = 0; final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a", MinorType.BIGINT.getType()))); - VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator); - listener.start(root); - while (true) { - while (!listener.isReady()) { - try { - Thread.sleep(1); - sleepTime.addAndGet(1L); - } catch (InterruptedException ignore) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) { + listener.start(root); + while (true) { + while (!listener.isReady()) { + try { + Thread.sleep(1); + sleepTime.addAndGet(1L); + } catch (InterruptedException ignore) { + } } - } - if (batches > 100) { - root.clear(); - listener.completed(); - return; - } + if (batches > 100) { + root.clear(); + listener.completed(); + return; + } - root.allocateNew(); - root.setRowCount(4095); - listener.putNext(); - batches++; + root.allocateNew(); + root.setRowCount(4095); + listener.putNext(); + batches++; + } } } }; @@ -121,16 +120,15 @@ public void getStream(CallContext context, Ticket ticket, try ( BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE); FlightServer server = - FlightTestUtil.getStartedServer( - (port) -> FlightServer.builder(serverAllocator, Location.forGrpcInsecure("localhost", port), producer) - .build()); + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(serverAllocator, location, producer) + .build()); BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, 
Long.MAX_VALUE); FlightClient client = FlightClient - .builder(clientAllocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) - .build() + .builder(clientAllocator, server.getLocation()) + .build(); + FlightStream stream = client.getStream(new Ticket(new byte[1])) ) { - FlightStream stream = client.getStream(new Ticket(new byte[1])); VectorSchemaRoot root = stream.getRoot(); root.clear(); Thread.sleep(wait); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index f8413b0dbdb..abc5a2c321d 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -19,14 +19,12 @@ import java.net.URISyntaxException; import java.util.Iterator; -import java.util.concurrent.Callable; import java.util.function.BiConsumer; import java.util.function.Consumer; import org.apache.arrow.flight.FlightClient.ClientStreamListener; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType; -import org.apache.arrow.flight.impl.Flight.PutResult; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; @@ -98,7 +96,8 @@ public void putStream() throws Exception { IntVector iv = new IntVector("c1", a); VectorSchemaRoot root = VectorSchemaRoot.of(iv); - ClientStreamListener listener = c.startPut(FlightDescriptor.path("hello"), root); + ClientStreamListener listener = c + .startPut(FlightDescriptor.path("hello"), root, new AsyncPutListener()); //batch 1 root.allocateNew(); @@ -155,12 +154,11 @@ private void test(BiConsumer consumer) throws Exc Producer producer = new Producer(a); FlightServer s = FlightTestUtil.getStartedServer( - (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), 
producer).build() + (location) -> FlightServer.builder(a, location, producer).build() )) { try ( - FlightClient c = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())) - .build() + FlightClient c = FlightClient.builder(a, s.getLocation()).build() ) { try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) { consumer.accept(c, testAllocator); @@ -199,14 +197,12 @@ public void listFlights(CallContext context, Criteria criteria, } @Override - public Callable acceptPut(CallContext context, - FlightStream flightStream) { + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener ackStream) { return () -> { try (VectorSchemaRoot root = flightStream.getRoot()) { while (flightStream.next()) { } - return PutResult.getDefaultInstance(); } }; } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java index 71d99862fda..3acb9473006 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -69,11 +69,8 @@ void test(Consumer testFn) { BufferAllocator a = new RootAllocator(Long.MAX_VALUE); Producer producer = new Producer(a); FlightServer s = - FlightTestUtil.getStartedServer( - (port) -> FlightServer.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port), producer) - .build()); - FlightClient client = FlightClient.builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())) - .build()) { + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()); + FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { testFn.accept(client); } catch (InterruptedException | IOException e) { throw new RuntimeException(e); diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java 
b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java index 99135484623..629b6f5ebd8 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestLargeMessage.java @@ -19,10 +19,8 @@ import java.util.Arrays; import java.util.List; -import java.util.concurrent.Callable; import java.util.stream.Stream; -import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; @@ -43,13 +41,11 @@ public void getLargeMessage() throws Exception { try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final Producer producer = new Producer(a); final FlightServer s = - FlightTestUtil.getStartedServer( - (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), producer).build())) { + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build())) { - try (FlightClient client = FlightClient - .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build()) { - FlightStream stream = client.getStream(new Ticket(new byte[]{})); - try (VectorSchemaRoot root = stream.getRoot()) { + try (FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { + try (FlightStream stream = client.getStream(new Ticket(new byte[]{})); + VectorSchemaRoot root = stream.getRoot()) { while (stream.next()) { for (final Field field : root.getSchema().getFields()) { int value = 0; @@ -61,7 +57,6 @@ public void getLargeMessage() throws Exception { } } } - stream.close(); } } } @@ -74,18 +69,17 @@ public void putLargeMessage() throws Exception { try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final Producer producer = new Producer(a); final FlightServer s = - FlightTestUtil.getStartedServer( - (port) -> FlightServer.builder(a, Location.forGrpcInsecure("localhost", port), 
producer).build() + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build() )) { - try (FlightClient client = FlightClient - .builder(a, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, s.getPort())).build(); + try (FlightClient client = FlightClient.builder(a, s.getLocation()).build(); BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE); VectorSchemaRoot root = generateData(testAllocator)) { - final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root); + final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root, + new AsyncPutListener()); listener.putNext(); listener.completed(); - Assert.assertEquals(listener.getResult(), Flight.PutResult.getDefaultInstance()); + listener.getResult(); } } } @@ -141,14 +135,12 @@ public FlightInfo getFlightInfo(CallContext context, } @Override - public Callable acceptPut(CallContext context, - FlightStream flightStream) { + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener ackStream) { return () -> { try (VectorSchemaRoot root = flightStream.getRoot()) { while (flightStream.next()) { ; } - return Flight.PutResult.getDefaultInstance(); } }; } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java index c22304d5647..b9d4dea5572 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java @@ -96,9 +96,9 @@ void test(Consumer testFn) { Producer producer = new Producer(); FlightServer s = FlightTestUtil.getStartedServer( - (port) -> { + (location) -> { try { - return FlightServer.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, port), producer) + return FlightServer.builder(a, location, producer) .useTls(certKey.cert, certKey.key) .build(); } catch 
(IOException e) { diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java index 39b2924c620..54bbadb0369 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java @@ -29,7 +29,6 @@ import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.FlightTestUtil; -import org.apache.arrow.flight.Location; import org.apache.arrow.flight.NoOpFlightProducer; import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; @@ -119,9 +118,9 @@ public byte[] getToken(String username, String password) { } }; - server = FlightTestUtil.getStartedServer((port) -> FlightServer.builder( + server = FlightTestUtil.getStartedServer((location) -> FlightServer.builder( allocator, - Location.forGrpcInsecure("localhost", port), + location, new NoOpFlightProducer() { @Override public void listFlights(CallContext context, Criteria criteria, @@ -150,8 +149,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l listener.completed(); } }).authHandler(new BasicServerAuthHandler(validator)).build()); - client = FlightClient.builder(allocator, Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort())) - .build(); + client = FlightClient.builder(allocator, server.getLocation()).build(); } @After diff --git a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java index 097c92cfe19..fb157f45ed1 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java @@ -19,6 +19,7 @@ import java.io.IOException; +import org.apache.arrow.flight.AsyncPutListener; import 
org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightClient.ClientStreamListener; import org.apache.arrow.flight.FlightDescriptor; @@ -33,12 +34,12 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; /** * Ensure that example server supports get and put. */ -@org.junit.Ignore public class TestExampleServer { private BufferAllocator allocator; @@ -68,6 +69,7 @@ public void after() throws Exception { } @Test + @Ignore public void putStream() { BufferAllocator a = caseAllocator; final int size = 10; @@ -75,7 +77,8 @@ public void putStream() { IntVector iv = new IntVector("c1", a); VectorSchemaRoot root = VectorSchemaRoot.of(iv); - ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root); + ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root, + new AsyncPutListener()); //batch 1 root.allocateNew(); @@ -102,10 +105,13 @@ public void putStream() { listener.getResult(); FlightInfo info = client.getInfo(FlightDescriptor.path("hello")); - FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket()); - VectorSchemaRoot newRoot = stream.getRoot(); - while (stream.next()) { - newRoot.clear(); + try (final FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket())) { + VectorSchemaRoot newRoot = stream.getRoot(); + while (stream.next()) { + newRoot.clear(); + } + } catch (Exception e) { + throw new RuntimeException(e); } } } diff --git a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java index d8d6e671d56..72099b987c4 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java @@ -21,21 +21,14 @@ import java.nio.ByteBuffer; 
import java.util.ArrayList; import java.util.List; -import java.util.concurrent.Callable; -import org.apache.arrow.flight.Action; -import org.apache.arrow.flight.ActionType; -import org.apache.arrow.flight.Criteria; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightProducer; import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.NoOpFlightProducer; import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.impl.Flight.PutResult; import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf; import org.apache.arrow.flight.perf.impl.PerfOuterClass.Token; import org.apache.arrow.memory.BufferAllocator; @@ -79,7 +72,7 @@ public void close() throws Exception { AutoCloseables.close(flightServer, allocator); } - private final class PerfProducer implements FlightProducer { + private final class PerfProducer extends NoOpFlightProducer { @Override public void getStream(CallContext context, Ticket ticket, @@ -145,11 +138,6 @@ public void getStream(CallContext context, Ticket ticket, } } - @Override - public void listFlights(CallContext context, Criteria criteria, - StreamListener listener) { - } - @Override public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { @@ -181,24 +169,6 @@ public FlightInfo getFlightInfo(CallContext context, throw new RuntimeException(e); } } - - @Override - public Callable acceptPut(CallContext context, - FlightStream flightStream) { - return null; - } - - @Override - public void doAction(CallContext context, Action action, - StreamListener listener) { - listener.onCompleted(); - } - - @Override - public void listActions(CallContext context, - StreamListener listener) { - } - } } diff --git 
a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java index a9b9d60b1b9..c23c793612d 100644 --- a/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java +++ b/java/flight/src/test/java/org/apache/arrow/flight/perf/TestPerf.java @@ -28,7 +28,6 @@ import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.FlightTestUtil; -import org.apache.arrow.flight.Location; import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf; import org.apache.arrow.memory.BufferAllocator; @@ -81,8 +80,7 @@ public void throughput() throws Exception { try ( final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); final PerformanceTestServer server = - FlightTestUtil.getStartedServer((port) -> new PerformanceTestServer(a, - Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, port))); + FlightTestUtil.getStartedServer((location) -> new PerformanceTestServer(a, location)); final FlightClient client = FlightClient.builder(a, server.getLocation()).build(); ) { final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2)); @@ -93,11 +91,13 @@ public void throughput() throws Exception { .map(t -> pool.submit(t)) .collect(Collectors.toList()); - Futures.whenAllSucceed(results); - Result r = new Result(); - for (ListenableFuture f : results) { - r.add(f.get()); - } + final Result r = Futures.whenAllSucceed(results).call(() -> { + Result res = new Result(); + for (ListenableFuture f : results) { + res.add(f.get()); + } + return res; + }).get(); double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000; System.out.println(String.format( @@ -127,28 +127,29 @@ public Consumer(FlightClient client, Ticket ticket) { public Result call() throws Exception { final Result r = new Result(); Stopwatch watch = Stopwatch.createStarted(); - FlightStream stream = client.getStream(ticket); - 
final VectorSchemaRoot root = stream.getRoot(); - try { - BigIntVector a = (BigIntVector) root.getVector("a"); - while (stream.next()) { - int rows = root.getRowCount(); - long aSum = r.aSum; - for (int i = 0; i < rows; i++) { - if (VALIDATE) { - aSum += a.get(i); + try (final FlightStream stream = client.getStream(ticket)) { + final VectorSchemaRoot root = stream.getRoot(); + try { + BigIntVector a = (BigIntVector) root.getVector("a"); + while (stream.next()) { + int rows = root.getRowCount(); + long aSum = r.aSum; + for (int i = 0; i < rows; i++) { + if (VALIDATE) { + aSum += a.get(i); + } } + r.bytes += rows * 32; + r.rows += rows; + r.aSum = aSum; + r.batches++; } - r.bytes += rows * 32; - r.rows += rows; - r.aSum = aSum; - r.batches++; - } - r.nanos = watch.elapsed(TimeUnit.NANOSECONDS); - return r; - } finally { - root.clear(); + r.nanos = watch.elapsed(TimeUnit.NANOSECONDS); + return r; + } finally { + root.clear(); + } } } diff --git a/java/pom.xml b/java/pom.xml index 540b41b2d54..916b7f1e3e2 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -33,7 +33,7 @@ 5.4.0 1.7.25 20.0 - 4.1.22.Final + 4.1.27.Final 2.9.8 2.7.1 1.9.0 diff --git a/python/examples/flight/server.py b/python/examples/flight/server.py index 72ed590e71c..3b699723df4 100644 --- a/python/examples/flight/server.py +++ b/python/examples/flight/server.py @@ -77,7 +77,7 @@ def do_get(self, context, ticket): return None return pyarrow.flight.RecordBatchStream(self.flights[key]) - def list_actions(self): + def list_actions(self, context): return [ ("clear", "Clear the stored flights."), ("shutdown", "Shut down this server."), diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 7ca83a94994..7fc4ed48c99 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -44,6 +44,15 @@ cdef class FlightCallOptions: CFlightCallOptions options def __init__(self, timeout=None): + """Create call options. 
+ + Parameters + ---------- + timeout : float or None + A timeout for the call, in seconds. None means that the + timeout defaults to an implementation-specific value. + + """ if timeout is not None: self.options.timeout = CTimeoutDuration(timeout) @@ -70,14 +79,24 @@ cdef class Action: CAction action def __init__(self, action_type, buf): + """Create an action from a type and a buffer. + + Parameters + ---------- + action_type : bytes or str + buf : Buffer or bytes-like object + """ self.action.type = tobytes(action_type) self.action.body = pyarrow_unwrap_buffer(as_buffer(buf)) @property def type(self): + """The action type.""" return frombytes(self.action.type) + @property def body(self): + """The action body (arguments for the action).""" return pyarrow_wrap_buffer(self.action.body) @staticmethod @@ -92,10 +111,16 @@ _ActionType = collections.namedtuple('_ActionType', ['type', 'description']) class ActionType(_ActionType): - """A type of action executable on a Flight service.""" + """A type of action that is executable on a Flight service.""" def make_action(self, buf): - """Create an Action with this type.""" + """Create an Action with this type. + + Parameters + ---------- + buf : obj + An Arrow buffer or Python bytes or bytes-like object. + """ return Action(self.type, buf) @@ -105,6 +130,12 @@ cdef class Result: unique_ptr[CResult] result def __init__(self, buf): + """Create a new result. + + Parameters + ---------- + buf : Buffer or bytes-like object + """ self.result.reset(new CResult()) self.result.get().body = pyarrow_unwrap_buffer(as_buffer(buf)) @@ -115,6 +146,23 @@ cdef class Result: class DescriptorType(enum.Enum): + """ + The type of a FlightDescriptor. + + Attributes + ---------- + + UNKNOWN + An unknown descriptor type. + + PATH + A Flight stream represented by a path. + + CMD + A Flight stream represented by an application-defined command. 
+ + """ + UNKNOWN = 0 PATH = 1 CMD = 2 @@ -151,6 +199,7 @@ cdef class FlightDescriptor: @property def descriptor_type(self): + """Get the type of this descriptor.""" if self.descriptor.type == CDescriptorTypeUnknown: return DescriptorType.UNKNOWN elif self.descriptor.type == CDescriptorTypePath: @@ -309,6 +358,7 @@ cdef class FlightEndpoint: @property def ticket(self): + """Get the ticket in this endpoint.""" return Ticket(self.endpoint.ticket.ticket) @property @@ -400,12 +450,149 @@ cdef class FlightInfo: return result -cdef class FlightRecordBatchReader(_CRecordBatchReader, _ReadPandasOption): +cdef class FlightStreamChunk: + """A RecordBatch with application metadata on the side.""" + cdef: + CFlightStreamChunk chunk + + @property + def data(self): + if self.chunk.data == NULL: + return None + return pyarrow_wrap_batch(self.chunk.data) + + @property + def app_metadata(self): + if self.chunk.app_metadata == NULL: + return None + return pyarrow_wrap_buffer(self.chunk.app_metadata) + + def __iter__(self): + return iter((self.data, self.app_metadata)) + + +cdef class _MetadataRecordBatchReader: + """A reader for Flight streams.""" + + # Needs to be separate class so the "real" class can subclass the + # pure-Python mixin class + cdef dict __dict__ + cdef shared_ptr[CMetadataRecordBatchReader] reader + + cdef readonly: + Schema schema + + +cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader, + _ReadPandasOption): + """A reader for Flight streams.""" + + def __iter__(self): + while True: + yield self.read_chunk() + + def read_all(self): + """Read the entire contents of the stream as a Table.""" + cdef: + shared_ptr[CTable] c_table + with nogil: + check_status(self.reader.get().ReadAll(&c_table)) + return pyarrow_wrap_table(c_table) + + def read_chunk(self): + """Read the next RecordBatch along with any metadata. + + Returns + ------- + data : RecordBatch + The next RecordBatch in the stream. 
+ app_metadata : Buffer or None + Application-specific metadata for the batch as defined by + Flight. + + Raises + ------ + StopIteration + when the stream is finished + """ + cdef: + FlightStreamChunk chunk = FlightStreamChunk() + + with nogil: + check_status(self.reader.get().Next(&chunk.chunk)) + + if chunk.chunk.data == NULL: + raise StopIteration + + return chunk + + +cdef class FlightStreamReader(MetadataRecordBatchReader): + """A reader that can also be canceled.""" + + def cancel(self): + """Cancel the read operation.""" + with nogil: + ( self.reader.get()).Cancel() -cdef class FlightRecordBatchWriter(_CRecordBatchWriter): - pass +cdef class FlightStreamWriter(_CRecordBatchWriter): + """A RecordBatchWriter that also allows writing application metadata.""" + + def write_with_metadata(self, RecordBatch batch, buf): + """Write a RecordBatch along with Flight metadata. + + Parameters + ---------- + batch : RecordBatch + The next RecordBatch in the stream. + buf : Buffer + Application-specific metadata for the batch as defined by + Flight. + """ + cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf)) + with nogil: + check_status( + ( self.writer.get()) + .WriteWithMetadata(deref(batch.batch), + c_buf, + 1)) + + +cdef class FlightMetadataReader: + """A reader for Flight metadata messages sent during a DoPut.""" + + cdef: + unique_ptr[CFlightMetadataReader] reader + + def read(self): + """Read the next metadata message.""" + cdef shared_ptr[CBuffer] buf + with nogil: + check_status(self.reader.get().ReadMetadata(&buf)) + if buf == NULL: + return None + return pyarrow_wrap_buffer(buf) + + +cdef class FlightMetadataWriter: + """A sender for Flight metadata messages during a DoPut.""" + + cdef: + unique_ptr[CFlightMetadataWriter] writer + + def write(self, message): + """Write the next metadata message. 
+ + Parameters + ---------- + message : Buffer + """ + cdef shared_ptr[CBuffer] buf = \ + pyarrow_unwrap_buffer(as_buffer(message)) + with nogil: + check_status(self.writer.get().WriteMetadata(deref(buf))) cdef class FlightClient: @@ -451,7 +638,15 @@ cdef class FlightClient: return result def authenticate(self, auth_handler, options: FlightCallOptions = None): - """Authenticate to the server.""" + """Authenticate to the server. + + Parameters + ---------- + auth_handler : ClientAuthHandler + The authentication mechanism to use. + options : FlightCallOptions + Options for this call. + """ cdef: unique_ptr[CClientAuthHandler] handler CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) @@ -539,34 +734,53 @@ cdef class FlightClient: return result def do_get(self, ticket: Ticket, options: FlightCallOptions = None): - """Request the data for a flight.""" + """Request the data for a flight. + + Returns + ------- + reader : FlightStreamReader + """ cdef: - unique_ptr[CRecordBatchReader] reader + unique_ptr[CFlightStreamReader] reader CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) with nogil: - check_status(self.client.get().DoGet( - deref(c_options), ticket.ticket, &reader)) - result = FlightRecordBatchReader() + check_status( + self.client.get().DoGet( + deref(c_options), ticket.ticket, &reader)) + result = FlightStreamReader() result.reader.reset(reader.release()) + result.schema = pyarrow_wrap_schema(result.reader.get().schema()) return result def do_put(self, descriptor: FlightDescriptor, schema: Schema, options: FlightCallOptions = None): - """Upload data to a flight.""" + """Upload data to a flight. 
+ + Returns + ------- + writer : FlightStreamWriter + reader : FlightMetadataReader + """ cdef: shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) - unique_ptr[CRecordBatchWriter] writer + unique_ptr[CFlightStreamWriter] writer + unique_ptr[CFlightMetadataReader] metadata_reader CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) CFlightDescriptor c_descriptor = \ FlightDescriptor.unwrap(descriptor) + FlightMetadataReader reader = FlightMetadataReader() with nogil: check_status(self.client.get().DoPut( - deref(c_options), c_descriptor, c_schema, &writer)) - result = FlightRecordBatchWriter() + deref(c_options), + c_descriptor, + c_schema, + &writer, + &reader.reader)) + result = FlightStreamWriter() result.writer.reset(writer.release()) - return result + return result, reader cdef class FlightDataStream: @@ -809,11 +1023,22 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *: payload.ipc_message.metadata.reset( nullptr) return + if isinstance(result, (list, tuple)): + result, metadata = result + else: + result, metadata = result, None + if isinstance(result, (Table, _CRecordBatchReader)): + if metadata: + raise ValueError("Can only return metadata alongside a " + "RecordBatch.") result = RecordBatchStream(result) stream_schema = pyarrow_wrap_schema(stream.schema) if isinstance(result, FlightDataStream): + if metadata: + raise ValueError("Can only return metadata alongside a " + "RecordBatch.") data_stream = unique_ptr[CFlightDataStream]( ( result).to_stream()) substream_schema = pyarrow_wrap_schema(data_stream.get().schema()) @@ -838,6 +1063,8 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *: deref(batch.batch), c_default_memory_pool(), &payload.ipc_message)) + if metadata: + payload.app_metadata = pyarrow_unwrap_buffer(as_buffer(metadata)) else: raise TypeError("GeneratorStream must be initialized with " "an iterator of FlightDataStream, Table, " @@ -880,17 +1107,22 @@ cdef void 
_get_flight_info(void* self, const CServerCallContext& context, cdef void _do_put(void* self, const CServerCallContext& context, - unique_ptr[CFlightMessageReader] reader) except *: + unique_ptr[CFlightMessageReader] reader, + unique_ptr[CFlightMetadataWriter] writer) except *: """Callback for implementing Flight servers in Python.""" cdef: - FlightRecordBatchReader py_reader = FlightRecordBatchReader() + MetadataRecordBatchReader py_reader = MetadataRecordBatchReader() + FlightMetadataWriter py_writer = FlightMetadataWriter() FlightDescriptor descriptor = \ FlightDescriptor.__new__(FlightDescriptor) descriptor.descriptor = reader.get().descriptor() py_reader.reader.reset(reader.release()) + py_reader.schema = pyarrow_wrap_schema( + py_reader.reader.get().schema()) + py_writer.writer.reset(writer.release()) ( self).do_put(ServerCallContext.wrap(context), descriptor, - py_reader) + py_reader, py_writer) cdef void _do_get(void* self, const CServerCallContext& context, @@ -943,7 +1175,7 @@ cdef void _list_actions(void* self, const CServerCallContext& context, cdef: CActionType action_type # Method should return a list of ActionTypes or similar tuple - result = ( self).list_actions() + result = ( self).list_actions(ServerCallContext.wrap(context)) for action in result: action_type.type = tobytes(action[0]) action_type.description = tobytes(action[1]) @@ -990,10 +1222,25 @@ cdef void _get_token(void* self, c_string* token) except *: cdef class ServerAuthHandler: - """Authentication middleware for a server.""" + """Authentication middleware for a server. + + To implement an authentication mechanism, subclass this class and + override its methods. + + """ def authenticate(self, outgoing, incoming): - """Conduct the handshake with the client.""" + """Conduct the handshake with the client. + + May raise an error if the client cannot authenticate. + + Parameters + ---------- + outgoing : ServerAuthSender + A channel to send messages to the client. 
+ incoming : ServerAuthReader + A channel to read messages from the client. + """ raise NotImplementedError def is_valid(self, token): @@ -1003,6 +1250,11 @@ cdef class ServerAuthHandler: name the peer) or raise an exception (if the token is invalid). + Parameters + ---------- + token : bytes + The authentication token from the client. + """ raise NotImplementedError @@ -1017,7 +1269,15 @@ cdef class ClientAuthHandler: """Authentication plugin for a client.""" def authenticate(self, outgoing, incoming): - """Conduct the handshake with the server.""" + """Conduct the handshake with the server. + + Parameters + ---------- + outgoing : ClientAuthSender + A channel to send messages to the server. + incoming : ClientAuthReader + A channel to read messages from the server. + """ raise NotImplementedError def get_token(self): @@ -1032,12 +1292,26 @@ cdef class ClientAuthHandler: cdef class FlightServerBase: - """A Flight service definition.""" + """A Flight service definition. + + Override methods to define your Flight service. + + """ cdef: unique_ptr[PyFlightServer] server def run(self, location, auth_handler=None, tls_certificates=None): + """Start this server. + + Parameters + ---------- + location : Location + auth_handler : ServerAuthHandler + An authentication mechanism to use. May be None. + tls_certificates : list + A list of (certificate, key) pairs. 
+ """ cdef: PyFlightServerVtable vtable = PyFlightServerVtable() PyFlightServer* c_server @@ -1078,7 +1352,8 @@ cdef class FlightServerBase: def get_flight_info(self, context, descriptor): raise NotImplementedError - def do_put(self, context, descriptor, reader): + def do_put(self, context, descriptor, reader, + writer: FlightMetadataWriter): raise NotImplementedError def do_get(self, context, ticket): diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 61e9571995d..49a515352cd 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -112,22 +112,50 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CSimpleFlightListing" arrow::flight::SimpleFlightListing": CSimpleFlightListing(vector[CFlightInfo]&& info) - cdef cppclass CFlightMessageReader \ - " arrow::flight::FlightMessageReader"(CRecordBatchReader): - CFlightDescriptor& descriptor() - cdef cppclass CFlightPayload" arrow::flight::FlightPayload": shared_ptr[CBuffer] descriptor + shared_ptr[CBuffer] app_metadata CIpcPayload ipc_message cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream": shared_ptr[CSchema] schema() CStatus Next(CFlightPayload*) + cdef cppclass CFlightStreamChunk" arrow::flight::FlightStreamChunk": + CFlightStreamChunk() + shared_ptr[CRecordBatch] data + shared_ptr[CBuffer] app_metadata + + cdef cppclass CMetadataRecordBatchReader \ + " arrow::flight::MetadataRecordBatchReader": + shared_ptr[CSchema] schema() + CStatus Next(CFlightStreamChunk* out) + CStatus ReadAll(shared_ptr[CTable]* table) + + cdef cppclass CFlightStreamReader \ + " arrow::flight::FlightStreamReader"(CMetadataRecordBatchReader): + void Cancel() + + cdef cppclass CFlightMessageReader \ + " arrow::flight::FlightMessageReader"(CMetadataRecordBatchReader): + CFlightDescriptor& descriptor() + + cdef cppclass CFlightStreamWriter \ + " 
arrow::flight::FlightStreamWriter"(CRecordBatchWriter): + CStatus WriteWithMetadata(const CRecordBatch& batch, + shared_ptr[CBuffer] app_metadata, + c_bool allow_64bit) + cdef cppclass CRecordBatchStream \ " arrow::flight::RecordBatchStream"(CFlightDataStream): CRecordBatchStream(shared_ptr[CRecordBatchReader]& reader) + cdef cppclass CFlightMetadataReader" arrow::flight::FlightMetadataReader": + CStatus ReadMetadata(shared_ptr[CBuffer]* out) + + cdef cppclass CFlightMetadataWriter" arrow::flight::FlightMetadataWriter": + CStatus WriteMetadata(const CBuffer& message) + cdef cppclass CServerAuthReader" arrow::flight::ServerAuthReader": CStatus Read(c_string* token) @@ -193,11 +221,12 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: unique_ptr[CFlightInfo]* info) CStatus DoGet(CFlightCallOptions& options, CTicket& ticket, - unique_ptr[CRecordBatchReader]* stream) + unique_ptr[CFlightStreamReader]* stream) CStatus DoPut(CFlightCallOptions& options, CFlightDescriptor& descriptor, shared_ptr[CSchema]& schema, - unique_ptr[CRecordBatchWriter]* stream) + unique_ptr[CFlightStreamWriter]* stream, + unique_ptr[CFlightMetadataReader]* reader) # Callbacks for implementing Flight servers @@ -209,7 +238,8 @@ ctypedef void cb_get_flight_info(object, const CServerCallContext&, const CFlightDescriptor&, unique_ptr[CFlightInfo]*) ctypedef void cb_do_put(object, const CServerCallContext&, - unique_ptr[CFlightMessageReader]) + unique_ptr[CFlightMessageReader], + unique_ptr[CFlightMetadataWriter]) ctypedef void cb_do_get(object, const CServerCallContext&, const CTicket&, unique_ptr[CFlightDataStream]*) diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 3088a7a86f1..3f83a1c9adf 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -20,13 +20,13 @@ import contextlib import os import socket +import struct import tempfile import threading import time import traceback import pytest - import 
pyarrow as pa from pyarrow.compat import tobytes @@ -114,17 +114,57 @@ def do_get(self, context, ticket): return flight.RecordBatchStream(table) +class MetadataFlightServer(flight.FlightServerBase): + """A Flight server that numbers incoming/outgoing data.""" + + def do_get(self, context, ticket): + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + return flight.GeneratorStream( + table.schema, + self.number_batches(table)) + + def do_put(self, context, descriptor, reader, writer): + counter = 0 + expected_data = [-10, -5, 0, 5, 10] + while True: + try: + batch, buf = reader.read_chunk() + assert batch.equals(pa.RecordBatch.from_arrays( + [pa.array([expected_data[counter]])], + ['a'] + )) + assert buf is not None + client_counter, = struct.unpack('