diff --git a/ci/travis_script_python.sh b/ci/travis_script_python.sh index 710ebb92646..5c95eb5fefb 100755 --- a/ci/travis_script_python.sh +++ b/ci/travis_script_python.sh @@ -90,6 +90,10 @@ CMAKE_COMMON_FLAGS="-DARROW_EXTRA_ERROR_CONTEXT=ON" PYTHON_CPP_BUILD_TARGETS="arrow_python-all plasma parquet" +if [ "$ARROW_TRAVIS_FLIGHT" == "1" ]; then + CMAKE_COMMON_FLAGS="$CMAKE_COMMON_FLAGS -DARROW_FLIGHT=ON" +fi + if [ "$ARROW_TRAVIS_COVERAGE" == "1" ]; then CMAKE_COMMON_FLAGS="$CMAKE_COMMON_FLAGS -DARROW_GENERATE_COVERAGE=ON" fi diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 6854f14f068..d1f48525233 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -274,6 +274,7 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_BENCHMARKS) # that depend on gtest add_arrow_lib(arrow_testing SOURCES + ipc/test-common.cc testing/gtest_util.cc testing/random.cc OUTPUTS diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc index ad90e46ced9..6826c7294d8 100644 --- a/cpp/src/arrow/array-test.cc +++ b/cpp/src/arrow/array-test.cc @@ -34,7 +34,6 @@ #include "arrow/buffer-builder.h" #include "arrow/buffer.h" #include "arrow/builder.h" -#include "arrow/ipc/test-common.h" #include "arrow/memory_pool.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -1690,45 +1689,6 @@ TEST_F(TestAdaptiveUIntBuilder, TestAppendNulls) { } } -// ---------------------------------------------------------------------- -// Union tests - -TEST(TestUnionArrayAdHoc, TestSliceEquals) { - std::shared_ptr batch; - ASSERT_OK(ipc::MakeUnion(&batch)); - - const int64_t size = batch->num_rows(); - - auto CheckUnion = [&size](std::shared_ptr array) { - std::shared_ptr slice, slice2; - slice = array->Slice(2); - ASSERT_EQ(size - 2, slice->length()); - - slice2 = array->Slice(2); - ASSERT_EQ(size - 2, slice->length()); - - ASSERT_TRUE(slice->Equals(slice2)); - ASSERT_TRUE(array->RangeEquals(2, array->length(), 0, slice)); - - // Chained slices - slice2 = array->Slice(1)->Slice(1); - ASSERT_TRUE(slice->Equals(slice2)); - - slice = array->Slice(1, 5); - slice2 = array->Slice(1, 5); - ASSERT_EQ(5, slice->length()); - - ASSERT_TRUE(slice->Equals(slice2)); - ASSERT_TRUE(array->RangeEquals(1, 6, 0, slice)); - - AssertZeroPadded(*array); - TestInitialized(*array); - }; - - CheckUnion(batch->column(1)); - CheckUnion(batch->column(2)); -} - using DecimalVector = std::vector; class DecimalTest : public ::testing::TestWithParam { diff --git a/cpp/src/arrow/array-union-test.cc b/cpp/src/arrow/array-union-test.cc new file mode 100644 index 00000000000..067d195dedc --- /dev/null +++ b/cpp/src/arrow/array-union-test.cc @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include + +#include "arrow/array.h" +#include "arrow/builder.h" +#include "arrow/status.h" +// TODO ipc shouldn't be included here +#include "arrow/ipc/test-common.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +using internal::checked_cast; + +TEST(TestUnionArrayAdHoc, TestSliceEquals) { + std::shared_ptr batch; + ASSERT_OK(ipc::test::MakeUnion(&batch)); + + auto CheckUnion = [](std::shared_ptr array) { + const int64_t size = array->length(); + std::shared_ptr slice, slice2; + slice = array->Slice(2); + ASSERT_EQ(size - 2, slice->length()); + + slice2 = array->Slice(2); + ASSERT_EQ(size - 2, slice->length()); + + ASSERT_TRUE(slice->Equals(slice2)); + ASSERT_TRUE(array->RangeEquals(2, array->length(), 0, slice)); + + // Chained slices + slice2 = array->Slice(1)->Slice(1); + ASSERT_TRUE(slice->Equals(slice2)); + + slice = array->Slice(1, 5); + slice2 = array->Slice(1, 5); + ASSERT_EQ(5, slice->length()); + + ASSERT_TRUE(slice->Equals(slice2)); + ASSERT_TRUE(array->RangeEquals(1, 6, 0, slice)); + + AssertZeroPadded(*array); + TestInitialized(*array); + }; + + CheckUnion(batch->column(1)); + CheckUnion(batch->column(2)); +} + +} // namespace arrow diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 154d34f826b..28b237d4348 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -16,7 +16,6 @@ // under the License. #include "arrow/flight/client.h" -#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep #include #include @@ -30,10 +29,10 @@ #include #endif -#include "arrow/ipc/dictionary.h" -#include "arrow/ipc/metadata-internal.h" +#include "arrow/buffer.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" +#include "arrow/memory_pool.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/type.h" @@ -58,86 +57,90 @@ struct ClientRpc { /// XXX workaround until we have a handshake in Connect context.set_wait_for_ready(true); } + + Status IOError(const std::string& error_message) { + std::stringstream ss; + ss << error_message << context.debug_error_string(); + return Status::IOError(ss.str()); + } }; -class FlightStreamReader : public RecordBatchReader { +class FlightIpcMessageReader : public ipc::MessageReader { public: - FlightStreamReader(std::unique_ptr rpc, - const std::shared_ptr& schema, - std::unique_ptr> stream) - : rpc_(std::move(rpc)), - stream_finished_(false), - schema_(schema), - stream_(std::move(stream)) {} - - std::shared_ptr schema() const override { return schema_; } - - Status ReadNext(std::shared_ptr* out) override { - internal::FlightData data; + FlightIpcMessageReader(std::unique_ptr rpc, + std::unique_ptr> stream) + : rpc_(std::move(rpc)), stream_(std::move(stream)), stream_finished_(false) {} + Status ReadNextMessage(std::unique_ptr* out) override { if (stream_finished_) { *out = nullptr; return Status::OK(); } - - // Pretend to be pb::FlightData and intercept in SerializationTraits - if (stream_->Read(reinterpret_cast(&data))) { - std::unique_ptr message; - - // Validate IPC message - RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); - if (message->type() == ipc::Message::Type::RECORD_BATCH) { - return ipc::ReadRecordBatch(*message, schema_, out); - } else if (message->type() == ipc::Message::Type::SCHEMA) { - return Status(StatusCode::Invalid, "Flight stream changed schema midway"); - } else { - return Status(StatusCode::Invalid, "Unrecognized message in Flight stream"); - } - } else { + internal::FlightData data; + if (!internal::ReadPayload(stream_.get(), &data)) { // Stream is completed stream_finished_ = true; *out = nullptr; - return internal::FromGrpcStatus(stream_->Finish()); + return OverrideWithServerError(Status::OK()); } + // Validate IPC message + auto st = data.OpenMessage(out); + if (!st.ok()) { + return OverrideWithServerError(std::move(st)); + } + return Status::OK(); + } + + protected: + Status OverrideWithServerError(Status&& st) { + // Get the gRPC status if not OK, to propagate any server error message + RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish())); + return st; } - private: // The RPC context lifetime must be coupled to the ClientReader std::unique_ptr rpc_; - - bool stream_finished_; - std::shared_ptr schema_; std::unique_ptr> stream_; + bool stream_finished_; }; -/// \brief A RecordBatchWriter implementation that writes to a Flight -/// DoPut stream. -class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { +/// A IpcPayloadWriter implementation that writes to a DoPut stream +class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { public: - explicit FlightPutWriterImpl(std::unique_ptr rpc, - const FlightDescriptor& descriptor, - const std::shared_ptr& schema, - MemoryPool* pool = default_memory_pool()) - : rpc_(std::move(rpc)), descriptor_(descriptor), schema_(schema), pool_(pool) {} + DoPutPayloadWriter(const FlightDescriptor& descriptor, std::unique_ptr rpc, + std::unique_ptr response, + std::unique_ptr> writer) + : descriptor_(descriptor), + rpc_(std::move(rpc)), + response_(std::move(response)), + writer_(std::move(writer)), + first_payload_(true) {} + + ~DoPutPayloadWriter() override = default; - Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + Status Start() override { return Status::OK(); } + + Status WritePayload(const ipc::internal::IpcPayload& ipc_payload) override { FlightPayload payload; - RETURN_NOT_OK( - ipc::internal::GetRecordBatchPayload(batch, pool_, &payload.ipc_message)); + payload.ipc_message = ipc_payload; + + if (first_payload_) { + // First Flight message needs to encore the Flight descriptor + DCHECK_EQ(ipc_payload.type, ipc::Message::SCHEMA); + std::string str_descr; + { + pb::FlightDescriptor pb_descr; + RETURN_NOT_OK(internal::ToProto(descriptor_, &pb_descr)); + if (!pb_descr.SerializeToString(&str_descr)) { + return Status::UnknownError("Failed to serialized Flight descriptor"); + } + } + RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor)); + first_payload_ = false; + } -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif - if (!writer_->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif - std::stringstream ss; - ss << "Could not write record batch to stream: " - << rpc_->context.debug_error_string(); - return Status::IOError(ss.str()); + if (!internal::WritePayload(payload, writer_.get())) { + return rpc_->IOError("Could not write record batch to stream: "); } return Status::OK(); } @@ -152,43 +155,15 @@ class FlightPutWriter::FlightPutWriterImpl : public ipc::RecordBatchWriter { return Status::OK(); } - void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } - - private: - /// \brief Set the gRPC writer backing this Flight stream. - /// \param [in] writer the gRPC writer - void set_stream(std::unique_ptr> writer) { - writer_ = std::move(writer); - } - + protected: // TODO: there isn't a way to access this as a user. - protocol::PutResult response; + const FlightDescriptor descriptor_; std::unique_ptr rpc_; - FlightDescriptor descriptor_; - std::shared_ptr schema_; + std::unique_ptr response_; std::unique_ptr> writer_; - MemoryPool* pool_; - - // We need to reference some fields - friend class FlightClient; + bool first_payload_; }; -FlightPutWriter::~FlightPutWriter() {} - -FlightPutWriter::FlightPutWriter(std::unique_ptr impl) { - impl_ = std::move(impl); -} - -Status FlightPutWriter::WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { - return impl_->WriteRecordBatch(batch, allow_64bit); -} - -Status FlightPutWriter::Close() { return impl_->Close(); } - -void FlightPutWriter::set_memory_pool(MemoryPool* pool) { - return impl_->set_memory_pool(pool); -} - class FlightClient::FlightClientImpl { public: Status Connect(const std::string& host, int port) { @@ -218,13 +193,13 @@ class FlightClient::FlightClientImpl { std::vector flights; pb::FlightGetInfo pb_info; - FlightInfo::Data info_data; while (stream->Read(&pb_info)) { + FlightInfo::Data info_data; RETURN_NOT_OK(internal::FromProto(pb_info, &info_data)); - flights.emplace_back(FlightInfo(std::move(info_data))); + flights.emplace_back(std::move(info_data)); } - listing->reset(new SimpleFlightListing(flights)); + listing->reset(new SimpleFlightListing(std::move(flights))); return internal::FromGrpcStatus(stream->Finish()); } @@ -292,65 +267,23 @@ class FlightClient::FlightClientImpl { std::unique_ptr> stream( stub_->DoGet(&rpc->context, pb_ticket)); - // First message must be the schema - std::shared_ptr schema; - internal::FlightData data; - if (!stream->Read(reinterpret_cast(&data))) { - // Get the gRPC status if not OK, to get any server error - // messages - RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish())); - return Status(StatusCode::Invalid, "No data in Flight stream"); - } - std::unique_ptr message; - RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); - if (message->type() != ipc::Message::Type::SCHEMA) { - return Status(StatusCode::Invalid, "Flight stream did not start with schema"); - } - RETURN_NOT_OK(ipc::ReadSchema(*message, &schema)); - - *out = std::unique_ptr( - new FlightStreamReader(std::move(rpc), schema, std::move(stream))); - return Status::OK(); + std::unique_ptr message_reader( + new FlightIpcMessageReader(std::move(rpc), std::move(stream))); + return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out); } Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr& schema, - std::unique_ptr* stream) { + std::unique_ptr* out) { std::unique_ptr rpc(new ClientRpc); - std::unique_ptr out( - new FlightPutWriter::FlightPutWriterImpl(std::move(rpc), descriptor, schema)); - std::unique_ptr> write_stream( - stub_->DoPut(&out->rpc_->context, &out->response)); + std::unique_ptr response(new protocol::PutResult); + std::unique_ptr> writer( + stub_->DoPut(&rpc->context, response.get())); - // First write the descriptor and schema to the stream. - FlightPayload payload; - ipc::DictionaryMemo dictionary_memo; - RETURN_NOT_OK(ipc::internal::GetSchemaPayload(*schema, out->pool_, &dictionary_memo, - &payload.ipc_message)); - pb::FlightDescriptor pb_descr; - RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descr)); - std::string str_descr; - pb_descr.SerializeToString(&str_descr); - RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor)); - -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif - if (!write_stream->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif - std::stringstream ss; - ss << "Could not write descriptor and schema to stream: " - << rpc->context.debug_error_string(); - return Status::IOError(ss.str()); - } + std::unique_ptr payload_writer( + new DoPutPayloadWriter(descriptor, std::move(rpc), std::move(response), + std::move(writer))); - out->set_stream(std::move(write_stream)); - *stream = - std::unique_ptr(new FlightPutWriter(std::move(out))); - return Status::OK(); + return ipc::internal::OpenRecordBatchWriter(std::move(payload_writer), schema, out); } private: diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 6277c151034..3603908ca29 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -109,22 +109,5 @@ class ARROW_EXPORT FlightClient { std::unique_ptr impl_; }; -/// \brief An interface to upload record batches to a Flight server -class ARROW_EXPORT FlightPutWriter : public ipc::RecordBatchWriter { - public: - ~FlightPutWriter() override; - - Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override; - Status Close() override; - void set_memory_pool(MemoryPool* pool) override; - - private: - class FlightPutWriterImpl; - explicit FlightPutWriter(std::unique_ptr impl); - std::unique_ptr impl_; - - friend class FlightClient; -}; - } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/customize_protobuf.h b/cpp/src/arrow/flight/customize_protobuf.h index 58c168c23b9..1d67480a53e 100644 --- a/cpp/src/arrow/flight/customize_protobuf.h +++ b/cpp/src/arrow/flight/customize_protobuf.h @@ -100,6 +100,8 @@ template class SerializationTraits::value>::type> { public: + // In the functions below, we cast back the Message argument to its real + // type (see ReadPayload() and WritePayload() for the initial cast). static Status Serialize(const grpc::protobuf::Message& msg, ByteBuffer* bb, bool* own_buffer) { return arrow::flight::internal::FlightDataSerialize( diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index b09914690b7..cf67e29ab6f 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -47,6 +47,90 @@ namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { +void AssertEqual(const ActionType& expected, const ActionType& actual) { + ASSERT_EQ(expected.type, actual.type); + ASSERT_EQ(expected.description, actual.description); +} + +void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) { + ASSERT_TRUE(expected.Equals(actual)); +} + +void AssertEqual(const Ticket& expected, const Ticket& actual) { + ASSERT_EQ(expected.ticket, actual.ticket); +} + +void AssertEqual(const Location& expected, const Location& actual) { + ASSERT_EQ(expected.host, actual.host); + ASSERT_EQ(expected.port, actual.port); +} + +void AssertEqual(const std::vector& expected, + const std::vector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + AssertEqual(expected[i].ticket, actual[i].ticket); + + ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size()); + for (size_t j = 0; j < expected[i].locations.size(); ++j) { + AssertEqual(expected[i].locations[j], actual[i].locations[j]); + } + } +} + +template +void AssertEqual(const std::vector& expected, const std::vector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + AssertEqual(expected[i], actual[i]); + } +} + +void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) { + std::shared_ptr ex_schema, actual_schema; + ASSERT_OK(expected.GetSchema(&ex_schema)); + ASSERT_OK(actual.GetSchema(&actual_schema)); + + AssertSchemaEqual(*ex_schema, *actual_schema); + ASSERT_EQ(expected.total_records(), actual.total_records()); + ASSERT_EQ(expected.total_bytes(), actual.total_bytes()); + + AssertEqual(expected.descriptor(), actual.descriptor()); + AssertEqual(expected.endpoints(), actual.endpoints()); +} + +TEST(TestFlightDescriptor, Basics) { + auto a = FlightDescriptor::Command("select * from table"); + auto b = FlightDescriptor::Command("select * from table"); + auto c = FlightDescriptor::Command("select foo from table"); + auto d = FlightDescriptor::Path({"foo", "bar"}); + auto e = FlightDescriptor::Path({"foo", "baz"}); + auto f = FlightDescriptor::Path({"foo", "baz"}); + + ASSERT_EQ(a.ToString(), "FlightDescriptor"); + ASSERT_EQ(d.ToString(), "FlightDescriptor"); + ASSERT_TRUE(a.Equals(b)); + ASSERT_FALSE(a.Equals(c)); + ASSERT_FALSE(a.Equals(d)); + ASSERT_FALSE(d.Equals(e)); + ASSERT_TRUE(e.Equals(f)); +} + +TEST(TestFlightDescriptor, ToFromProto) { + FlightDescriptor descr_test; + pb::FlightDescriptor pb_descr; + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}}; + ASSERT_OK(internal::ToProto(descr1, &pb_descr)); + ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); + AssertEqual(descr1, descr_test); + + FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}}; + ASSERT_OK(internal::ToProto(descr2, &pb_descr)); + ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); + AssertEqual(descr2, descr_test); +} + TEST(TestFlight, StartStopTestServer) { TestServer server("flight-test-server", 30000); server.Start(); @@ -85,72 +169,52 @@ class TestFlightClient : public ::testing::Test { Status ConnectClient() { return FlightClient::Connect("localhost", port_, &client_); } + template + void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches, + EndpointCheckFunc&& check_endpoints) { + auto num_batches = static_cast(expected_batches.size()); + DCHECK_GE(num_batches, 2); + auto expected_schema = expected_batches[0]->schema(); + + std::unique_ptr info; + ASSERT_OK(client_->GetFlightInfo(descr, &info)); + check_endpoints(info->endpoints()); + + std::shared_ptr schema; + ASSERT_OK(info->GetSchema(&schema)); + AssertSchemaEqual(*expected_schema, *schema); + + // By convention, fetch the first endpoint + Ticket ticket = info->endpoints()[0].ticket; + std::unique_ptr stream; + ASSERT_OK(client_->DoGet(ticket, &stream)); + + std::shared_ptr chunk; + for (int i = 0; i < num_batches; ++i) { + ASSERT_OK(stream->ReadNext(&chunk)); + ASSERT_NE(nullptr, chunk); + ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk); + } + + // Stream exhausted + ASSERT_OK(stream->ReadNext(&chunk)); + ASSERT_EQ(nullptr, chunk); + } + protected: int port_; std::unique_ptr client_; std::unique_ptr server_; }; -// The server implementation is in test-server.cc; to make changes to the -// expected results, make edits there -void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {} - -void AssertEqual(const Ticket& expected, const Ticket& actual) { - ASSERT_EQ(expected.ticket, actual.ticket); -} - -void AssertEqual(const Location& expected, const Location& actual) { - ASSERT_EQ(expected.host, actual.host); - ASSERT_EQ(expected.port, actual.port); -} - -void AssertEqual(const std::vector& expected, - const std::vector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertEqual(expected[i].ticket, actual[i].ticket); - - ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size()); - for (size_t j = 0; j < expected[i].locations.size(); ++j) { - AssertEqual(expected[i].locations[j], actual[i].locations[j]); - } - } -} - -void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) { - std::shared_ptr ex_schema, actual_schema; - ASSERT_OK(expected.GetSchema(&ex_schema)); - ASSERT_OK(actual.GetSchema(&actual_schema)); - - AssertSchemaEqual(*ex_schema, *actual_schema); - ASSERT_EQ(expected.total_records(), actual.total_records()); - ASSERT_EQ(expected.total_bytes(), actual.total_bytes()); - - AssertEqual(expected.descriptor(), actual.descriptor()); - AssertEqual(expected.endpoints(), actual.endpoints()); -} - -void AssertEqual(const ActionType& expected, const ActionType& actual) { - ASSERT_EQ(expected.type, actual.type); - ASSERT_EQ(expected.description, actual.description); -} - -template -void AssertEqual(const std::vector& expected, const std::vector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertEqual(expected[i], actual[i]); - } -} - TEST_F(TestFlightClient, ListFlights) { std::unique_ptr listing; ASSERT_OK(client_->ListFlights(&listing)); ASSERT_TRUE(listing != nullptr); std::vector flights = ExampleFlightInfo(); - std::unique_ptr info; + std::unique_ptr info; for (const FlightInfo& flight : flights) { ASSERT_OK(listing->Next(&info)); AssertEqual(flight, *info); @@ -159,66 +223,56 @@ TEST_F(TestFlightClient, ListFlights) { ASSERT_TRUE(info == nullptr); ASSERT_OK(listing->Next(&info)); + ASSERT_TRUE(info == nullptr); } TEST_F(TestFlightClient, GetFlightInfo) { - FlightDescriptor descr{FlightDescriptor::PATH, "", {"foo", "bar"}}; + auto descr = FlightDescriptor::Path({"examples", "ints"}); std::unique_ptr info; - ASSERT_OK(client_->GetFlightInfo(descr, &info)); - ASSERT_TRUE(info != nullptr); + ASSERT_OK(client_->GetFlightInfo(descr, &info)); + ASSERT_NE(info, nullptr); std::vector flights = ExampleFlightInfo(); AssertEqual(flights[0], *info); } -TEST(TestFlightProtocol, FlightDescriptor) { - FlightDescriptor descr_test; - pb::FlightDescriptor pb_descr; - - FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}}; - ASSERT_OK(internal::ToProto(descr1, &pb_descr)); - ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); - AssertEqual(descr1, descr_test); - - FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}}; - ASSERT_OK(internal::ToProto(descr2, &pb_descr)); - ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); - AssertEqual(descr2, descr_test); -} - -TEST_F(TestFlightClient, DoGet) { - FlightDescriptor descr{FlightDescriptor::PATH, "", {"foo", "bar"}}; +TEST_F(TestFlightClient, GetFlightInfoNotFound) { + auto descr = FlightDescriptor::Path({"examples", "things"}); std::unique_ptr info; - ASSERT_OK(client_->GetFlightInfo(descr, &info)); - - // Two endpoints in the example FlightInfo - ASSERT_EQ(2, info->endpoints().size()); - - Ticket ticket = info->endpoints()[0].ticket; - AssertEqual(Ticket{"ticket-id-1"}, ticket); + // XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support + // multiple error codes. + auto st = client_->GetFlightInfo(descr, &info); + ASSERT_RAISES(IOError, st); + ASSERT_NE(st.message().find("Flight not found"), std::string::npos); +} - std::shared_ptr schema; - ASSERT_OK(info->GetSchema(&schema)); +TEST_F(TestFlightClient, DoGetInts) { + auto descr = FlightDescriptor::Path({"examples", "ints"}); + BatchVector expected_batches; + ASSERT_OK(ExampleIntBatches(&expected_batches)); - auto expected_schema = ExampleSchema1(); - AssertSchemaEqual(*expected_schema, *schema); + auto check_endpoints = [](const std::vector& endpoints) { + // Two endpoints in the example FlightInfo + ASSERT_EQ(2, endpoints.size()); + AssertEqual(Ticket{"ticket-ints-1"}, endpoints[0].ticket); + }; - std::unique_ptr stream; - ASSERT_OK(client_->DoGet(ticket, &stream)); + CheckDoGet(descr, expected_batches, check_endpoints); +} +TEST_F(TestFlightClient, DoGetDicts) { + auto descr = FlightDescriptor::Path({"examples", "dicts"}); BatchVector expected_batches; - const int num_batches = 5; - ASSERT_OK(SimpleIntegerBatches(num_batches, &expected_batches)); - std::shared_ptr chunk; - for (int i = 0; i < num_batches; ++i) { - ASSERT_OK(stream->ReadNext(&chunk)); - ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk); - } + ASSERT_OK(ExampleDictBatches(&expected_batches)); + + auto check_endpoints = [](const std::vector& endpoints) { + // One endpoint in the example FlightInfo + ASSERT_EQ(1, endpoints.size()); + AssertEqual(Ticket{"ticket-dicts-1"}, endpoints[0].ticket); + }; - // Stream exhausted - ASSERT_OK(stream->ReadNext(&chunk)); - ASSERT_EQ(nullptr, chunk); + CheckDoGet(descr, expected_batches, check_endpoints); } TEST_F(TestFlightClient, ListActions) { diff --git a/cpp/src/arrow/flight/protocol-internal.h b/cpp/src/arrow/flight/protocol-internal.h index 2e8dd32e559..848c1a801bd 100644 --- a/cpp/src/arrow/flight/protocol-internal.h +++ b/cpp/src/arrow/flight/protocol-internal.h @@ -16,6 +16,8 @@ #pragma once +// This header holds the Flight protobuf definitions. + // Need to include this first to get our gRPC customizations #include "arrow/flight/customize_protobuf.h" // IWYU pragma: export diff --git a/cpp/src/arrow/flight/serialization-internal.cc b/cpp/src/arrow/flight/serialization-internal.cc index b7e566cf978..b7a227a8ba7 100644 --- a/cpp/src/arrow/flight/serialization-internal.cc +++ b/cpp/src/arrow/flight/serialization-internal.cc @@ -24,6 +24,7 @@ #include "arrow/util/config.h" +#include #include #include #include @@ -307,6 +308,53 @@ grpc::Status FlightDataDeserialize(ByteBuffer* buffer, FlightData* out) { return grpc::Status::OK; } +Status FlightData::OpenMessage(std::unique_ptr* message) { + return ipc::Message::Open(metadata, body, message); +} + +// The pointer bitcast hack below causes legitimate warnings, silence them. +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +// Pointer bitcast explanation: grpc::*Writer::Write() and grpc::*Reader::Read() +// both take a T* argument (here pb::FlightData*). But they don't do anything +// with that argument except pass it to SerializationTraits::Serialize() and +// SerializationTraits::Deserialize(). +// +// Since we control SerializationTraits, we can interpret the +// pointer argument whichever way we want, including cast it back to the original type. +// (see customize_protobuf.h). + +bool WritePayload(const FlightPayload& payload, + grpc::ClientWriter* writer) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return writer->Write(*reinterpret_cast(&payload), + grpc::WriteOptions()); +} + +bool WritePayload(const FlightPayload& payload, + grpc::ServerWriter* writer) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return writer->Write(*reinterpret_cast(&payload), + grpc::WriteOptions()); +} + +bool ReadPayload(grpc::ClientReader* reader, FlightData* data) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return reader->Read(reinterpret_cast(data)); +} + +bool ReadPayload(grpc::ServerReader* reader, FlightData* data) { + // Pretend to be pb::FlightData and intercept in SerializationTraits + return reader->Read(reinterpret_cast(data)); +} + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/serialization-internal.h b/cpp/src/arrow/flight/serialization-internal.h index 457629062d2..aa47af6ae35 100644 --- a/cpp/src/arrow/flight/serialization-internal.h +++ b/cpp/src/arrow/flight/serialization-internal.h @@ -20,15 +20,12 @@ #pragma once -// Enable gRPC customizations -#include "arrow/flight/protocol-internal.h" // IWYU pragma: keep - #include -#include - #include "arrow/flight/internal.h" #include "arrow/flight/types.h" +#include "arrow/ipc/message.h" +#include "arrow/status.h" namespace arrow { @@ -48,8 +45,23 @@ struct FlightData { /// Message body std::shared_ptr body; + + /// Open IPC message from the metadata and body + Status OpenMessage(std::unique_ptr* message); }; +/// Write Flight message on gRPC stream with zero-copy optimizations. +/// True is returned on success, false if some error occurred (connection closed?). +bool WritePayload(const FlightPayload& payload, + grpc::ClientWriter* writer); +bool WritePayload(const FlightPayload& payload, + grpc::ServerWriter* writer); + +/// Read Flight message from gRPC stream with zero-copy optimizations. +/// True is returned on success, false if stream ended. +bool ReadPayload(grpc::ClientReader* reader, FlightData* data); +bool ReadPayload(grpc::ServerReader* reader, FlightData* data); + } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 5a8dc7e0d2a..29de44a52d9 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -16,7 +16,6 @@ // under the License. #include "arrow/flight/server.h" -#include "arrow/flight/protocol-internal.h" #include #include @@ -32,7 +31,7 @@ #include #endif -#include "arrow/ipc/dictionary.h" +#include "arrow/buffer.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/memory_pool.h" @@ -81,19 +80,11 @@ class FlightMessageReaderImpl : public FlightMessageReader { } internal::FlightData data; - // Pretend to be pb::FlightData and intercept in SerializationTraits -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif - if (reader_->Read(reinterpret_cast(&data))) { -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif + if (internal::ReadPayload(reader_, &data)) { std::unique_ptr message; // Validate IPC message - RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); + RETURN_NOT_OK(data.OpenMessage(&message)); if (message->type() == ipc::Message::Type::RECORD_BATCH) { return ipc::ReadRecordBatch(*message, schema_, out); } else { @@ -126,9 +117,9 @@ class FlightServiceImpl : public FlightService::Service { return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate"); } // Write flight info to stream until listing is exhausted - ProtoType pb_value; - std::unique_ptr value; while (true) { + ProtoType pb_value; + std::unique_ptr value; GRPC_RETURN_NOT_OK(iterator->Next(&value)); if (!value) { break; @@ -148,8 +139,8 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status WriteStream(const std::vector& values, ServerWriter* writer) { // Write flight info to stream until listing is exhausted - ProtoType pb_value; for (const UserType& value : values) { + ProtoType pb_value; GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value)); // Blocking write if (!writer->Write(pb_value)) { @@ -210,36 +201,34 @@ class FlightServiceImpl : public FlightService::Service { return grpc::Status(grpc::StatusCode::NOT_FOUND, "No data in this flight"); } - // Write the schema as the first message in the stream - FlightPayload schema_payload; + // Write the schema as the first message(s) in the stream + // (several messages may be required if there are dictionaries) MemoryPool* pool = default_memory_pool(); - ipc::DictionaryMemo dictionary_memo; - GRPC_RETURN_NOT_OK(ipc::internal::GetSchemaPayload( - *data_stream->schema(), pool, &dictionary_memo, &schema_payload.ipc_message)); - - // Pretend to be pb::FlightData, we cast back to FlightPayload in - // SerializationTraits -#ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif - writer->Write(*reinterpret_cast(&schema_payload), - grpc::WriteOptions()); + std::vector ipc_payloads; + GRPC_RETURN_NOT_OK( + ipc::internal::GetSchemaPayloads(*data_stream->schema(), pool, &ipc_payloads)); + + for (auto& ipc_payload : ipc_payloads) { + // For DoGet, descriptor doesn't need to be written out + FlightPayload schema_payload; + schema_payload.ipc_message = std::move(ipc_payload); + + if (!internal::WritePayload(schema_payload, writer)) { + // Connection terminated? XXX return error code? + return grpc::Status::OK; + } + } + // Write incoming data as individual messages while (true) { FlightPayload payload; GRPC_RETURN_NOT_OK(data_stream->Next(&payload)); if (payload.ipc_message.metadata == nullptr || - !writer->Write(*reinterpret_cast(&payload), - grpc::WriteOptions())) { + !internal::WritePayload(payload, writer)) // No more messages to write, or connection terminated for some other // reason break; - } } -#ifndef _WIN32 -#pragma GCC diagnostic pop -#endif return grpc::Status::OK; } @@ -247,10 +236,10 @@ class FlightServiceImpl : public FlightService::Service { pb::PutResult* response) { // Get metadata internal::FlightData data; - if (reader->Read(reinterpret_cast(&data))) { + if (internal::ReadPayload(reader, &data)) { // Message only lives as long as data std::unique_ptr message; - GRPC_RETURN_NOT_OK(ipc::Message::Open(data.metadata, data.body, &message)); + GRPC_RETURN_NOT_OK(data.OpenMessage(&message)); if (!message || message->type() != ipc::Message::Type::SCHEMA) { return internal::ToGrpcStatus( diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc index 316d89fa521..a7049db7256 100644 --- a/cpp/src/arrow/flight/test-server.cc +++ b/cpp/src/arrow/flight/test-server.cc @@ -38,9 +38,14 @@ namespace arrow { namespace flight { Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { - if (ticket.ticket == "ticket-id-1") { + if (ticket.ticket == "ticket-ints-1") { BatchVector batches; - RETURN_NOT_OK(SimpleIntegerBatches(5, &batches)); + RETURN_NOT_OK(ExampleIntBatches(&batches)); + *out = std::make_shared(batches[0]->schema(), batches); + return Status::OK(); + } else if (ticket.ticket == "ticket-dicts-1") { + BatchVector batches; + RETURN_NOT_OK(ExampleDictBatches(&batches)); *out = std::make_shared(batches[0]->schema(), batches); return Status::OK(); } else { @@ -57,20 +62,16 @@ class FlightTestServer : public FlightServerBase { } Status GetFlightInfo(const FlightDescriptor& request, - std::unique_ptr* info) override { + std::unique_ptr* out) override { std::vector flights = ExampleFlightInfo(); - const FlightInfo* value; - - // We only have one kind of flight for each descriptor type - if (request.type == FlightDescriptor::PATH) { - value = &flights[0]; - } else { - value = &flights[1]; + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *out = std::unique_ptr(new FlightInfo(info)); + return Status::OK(); + } } - - *info = std::unique_ptr(new FlightInfo(*value)); - return Status::OK(); + return Status::Invalid("Flight not found: ", request.ToString()); } Status DoGet(const Ticket& request, diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc index 7ce8ef50dc2..71ab30c86d6 100644 --- a/cpp/src/arrow/flight/test-util.cc +++ b/cpp/src/arrow/flight/test-util.cc @@ -128,28 +128,62 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, return internal::SchemaToString(schema, &out->schema); } +std::shared_ptr ExampleIntSchema() { + auto f0 = field("f0", int32()); + auto f1 = field("f1", int32()); + return ::arrow::schema({f0, f1}); +} + +std::shared_ptr ExampleStringSchema() { + auto f0 = field("f0", utf8()); + auto f1 = field("f1", binary()); + return ::arrow::schema({f0, f1}); +} + +std::shared_ptr ExampleDictSchema() { + std::shared_ptr batch; + ABORT_NOT_OK(ipc::test::MakeDictionary(&batch)); + return batch->schema(); +} + std::vector ExampleFlightInfo() { - FlightEndpoint endpoint1({{"ticket-id-1"}, {{"foo1.bar.com", 92385}}}); - FlightEndpoint endpoint2({{"ticket-id-2"}, {{"foo2.bar.com", 92385}}}); - FlightEndpoint endpoint3({{"ticket-id-3"}, {{"foo3.bar.com", 92385}}}); - FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}}; + FlightInfo::Data flight1, flight2, flight3; + + FlightEndpoint endpoint1({{"ticket-ints-1"}, {{"foo1.bar.com", 92385}}}); + FlightEndpoint endpoint2({{"ticket-ints-2"}, {{"foo2.bar.com", 92385}}}); + FlightEndpoint endpoint3({{"ticket-cmd"}, {{"foo3.bar.com", 92385}}}); + FlightEndpoint endpoint4({{"ticket-dicts-1"}, {{"foo4.bar.com", 92385}}}); + + FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}}; FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}}; + FlightDescriptor descr3{FlightDescriptor::PATH, "", {"examples", "dicts"}}; - auto schema1 = ExampleSchema1(); - auto schema2 = ExampleSchema2(); + auto schema1 = ExampleIntSchema(); + auto schema2 = ExampleStringSchema(); + auto schema3 = ExampleDictSchema(); - FlightInfo::Data flight1, flight2; ARROW_EXPECT_OK( MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000, &flight1)); ARROW_EXPECT_OK(MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000, &flight2)); - return {FlightInfo(flight1), FlightInfo(flight2)}; + ARROW_EXPECT_OK(MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1, &flight3)); + return {FlightInfo(flight1), FlightInfo(flight2), FlightInfo(flight3)}; } -Status SimpleIntegerBatches(const int num_batches, BatchVector* out) { +Status ExampleIntBatches(BatchVector* out) { std::shared_ptr batch; - for (int i = 0; i < num_batches; ++i) { + for (int i = 0; i < 5; ++i) { // Make all different sizes, use different random seed - RETURN_NOT_OK(ipc::MakeIntBatchSized(10 + i, &batch, i)); + RETURN_NOT_OK(ipc::test::MakeIntBatchSized(10 + i, &batch, i)); + out->push_back(batch); + } + return Status::OK(); +} + +Status ExampleDictBatches(BatchVector* out) { + // Just the same batch, repeated a few times + std::shared_ptr batch; + for (int i = 0; i < 3; ++i) { + RETURN_NOT_OK(ipc::test::MakeDictionary(&batch)); out->push_back(batch); } return Status::OK(); diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h index 006c966c377..0c41ec1c6cd 100644 --- a/cpp/src/arrow/flight/test-util.h +++ b/cpp/src/arrow/flight/test-util.h @@ -88,31 +88,28 @@ class BatchIterator : public RecordBatchReader { using BatchVector = std::vector>; -inline std::shared_ptr ExampleSchema1() { - auto f0 = field("f0", int32()); - auto f1 = field("f1", int32()); - return ::arrow::schema({f0, f1}); -} - -inline std::shared_ptr ExampleSchema2() { - auto f0 = field("f0", utf8()); - auto f1 = field("f1", binary()); - return ::arrow::schema({f0, f1}); -} +ARROW_EXPORT std::shared_ptr ExampleIntSchema(); + +ARROW_EXPORT std::shared_ptr ExampleStringSchema(); + +ARROW_EXPORT std::shared_ptr ExampleDictSchema(); ARROW_EXPORT -Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, - const std::vector& endpoints, int64_t total_records, - int64_t total_bytes, FlightInfo::Data* out); +Status ExampleIntBatches(BatchVector* out); ARROW_EXPORT -std::vector ExampleFlightInfo(); +Status ExampleDictBatches(BatchVector* out); ARROW_EXPORT -Status SimpleIntegerBatches(const int num_batches, BatchVector* out); +std::vector ExampleFlightInfo(); ARROW_EXPORT std::vector ExampleActionTypes(); +ARROW_EXPORT +Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, + const std::vector& endpoints, int64_t total_records, + int64_t total_bytes, FlightInfo::Data* out); + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index fb8f8c643be..3625bc5633f 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -18,6 +18,7 @@ #include "arrow/flight/types.h" #include +#include #include #include "arrow/io/memory.h" @@ -27,6 +28,47 @@ namespace arrow { namespace flight { +bool FlightDescriptor::Equals(const FlightDescriptor& other) const { + if (type != other.type) { + return false; + } + switch (type) { + case PATH: + return path == other.path; + case CMD: + return cmd == other.cmd; + default: + return false; + } +} + +std::string FlightDescriptor::ToString() const { + std::stringstream ss; + ss << "FlightDescriptor<"; + switch (type) { + case PATH: { + bool first = true; + ss << "path = '"; + for (const auto& p : path) { + if (!first) { + ss << "/"; + } + first = false; + ss << p; + } + ss << "'"; + break; + } + case CMD: + ss << "cmd = '" << cmd << "'"; + break; + default: + break; + } + ss << ">"; + return ss.str(); +} + Status FlightInfo::GetSchema(std::shared_ptr* out) const { if (reconstructed_schema_) { *out = schema_; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index ba0ab857484..0c09766298a 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -87,6 +87,20 @@ struct FlightDescriptor { /// List of strings identifying a particular dataset. Should only be defined /// when type is PATH std::vector path; + + bool Equals(const FlightDescriptor& other) const; + + std::string ToString() const; + + // Convenience factory functions + + static FlightDescriptor Command(const std::string& c) { + return FlightDescriptor{CMD, c, {}}; + } + + static FlightDescriptor Path(const std::vector& p) { + return FlightDescriptor{PATH, "", p}; + } }; /// \brief Data structure providing an opaque identifier or credential to use @@ -114,6 +128,8 @@ struct FlightEndpoint { }; /// \brief Staging data structure for messages about to be put on the wire +/// +/// This structure corresponds to FlightData in the protocol. struct FlightPayload { std::shared_ptr descriptor; ipc::internal::IpcPayload ipc_message; diff --git a/cpp/src/arrow/gpu/cuda-test.cc b/cpp/src/arrow/gpu/cuda-test.cc index 51366e1688b..9a10b2743c0 100644 --- a/cpp/src/arrow/gpu/cuda-test.cc +++ b/cpp/src/arrow/gpu/cuda-test.cc @@ -25,6 +25,7 @@ #include "arrow/ipc/test-common.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/util.h" #include "arrow/gpu/cuda_api.h" @@ -320,7 +321,7 @@ class TestCudaArrowIpc : public TestCudaBufferBase { TEST_F(TestCudaArrowIpc, BasicWriteRead) { std::shared_ptr batch; - ASSERT_OK(ipc::MakeIntRecordBatch(&batch)); + ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch)); std::shared_ptr device_serialized; ASSERT_OK(SerializeRecordBatch(*batch, context_.get(), &device_serialized)); diff --git a/cpp/src/arrow/ipc/dictionary.h b/cpp/src/arrow/ipc/dictionary.h index 4494b134cbf..69ea4855f78 100644 --- a/cpp/src/arrow/ipc/dictionary.h +++ b/cpp/src/arrow/ipc/dictionary.h @@ -42,6 +42,8 @@ using DictionaryTypeMap = std::unordered_map>; class ARROW_EXPORT DictionaryMemo { public: DictionaryMemo(); + DictionaryMemo(DictionaryMemo&&) = default; + DictionaryMemo& operator=(DictionaryMemo&&) = default; /// \brief Returns KeyError if dictionary not found Status GetDictionary(int64_t id, std::shared_ptr* dictionary) const; diff --git a/cpp/src/arrow/ipc/feather-test.cc b/cpp/src/arrow/ipc/feather-test.cc index e7b699d79ad..001e36ac0df 100644 --- a/cpp/src/arrow/ipc/feather-test.cc +++ b/cpp/src/arrow/ipc/feather-test.cc @@ -304,9 +304,9 @@ class TestTableReader : public ::testing::Test { TEST_F(TestTableReader, ReadIndices) { std::shared_ptr batch1; - ASSERT_OK(MakeIntRecordBatch(&batch1)); + ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1)); std::shared_ptr batch2; - ASSERT_OK(MakeIntRecordBatch(&batch2)); + ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch2)); ASSERT_OK(writer_->Append("f0", *batch1->column(0))); ASSERT_OK(writer_->Append("f1", *batch1->column(1))); @@ -329,9 +329,9 @@ TEST_F(TestTableReader, ReadIndices) { TEST_F(TestTableReader, ReadNames) { std::shared_ptr batch1; - ASSERT_OK(MakeIntRecordBatch(&batch1)); + ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch1)); std::shared_ptr batch2; - ASSERT_OK(MakeIntRecordBatch(&batch2)); + ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch2)); ASSERT_OK(writer_->Append("f0", *batch1->column(0))); ASSERT_OK(writer_->Append("f1", *batch1->column(1))); @@ -419,7 +419,7 @@ TEST_F(TestTableWriter, SetDescription) { TEST_F(TestTableWriter, PrimitiveRoundTrip) { std::shared_ptr batch; - ASSERT_OK(MakeIntRecordBatch(&batch)); + ASSERT_OK(ipc::test::MakeIntRecordBatch(&batch)); ASSERT_OK(writer_->Append("f0", *batch->column(0))); ASSERT_OK(writer_->Append("f1", *batch->column(1))); @@ -437,7 +437,7 @@ TEST_F(TestTableWriter, PrimitiveRoundTrip) { TEST_F(TestTableWriter, CategoryRoundtrip) { std::shared_ptr batch; - ASSERT_OK(MakeDictionaryFlat(&batch)); + ASSERT_OK(ipc::test::MakeDictionaryFlat(&batch)); CheckBatch(batch); } @@ -489,13 +489,13 @@ TEST_F(TestTableWriter, TimeTypes) { TEST_F(TestTableWriter, VLenPrimitiveRoundTrip) { std::shared_ptr batch; - ASSERT_OK(MakeStringTypesRecordBatch(&batch)); + ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&batch)); CheckBatch(batch); } TEST_F(TestTableWriter, PrimitiveNullRoundTrip) { std::shared_ptr batch; - ASSERT_OK(MakeNullRecordBatch(&batch)); + ASSERT_OK(ipc::test::MakeNullRecordBatch(&batch)); for (int i = 0; i < batch->num_columns(); ++i) { ASSERT_OK(writer_->Append(batch->column_name(i), *batch->column(i))); @@ -540,7 +540,7 @@ class TestTableWriterSlice : public TestTableWriter, TEST_P(TestTableWriterSlice, SliceRoundTrip) { std::shared_ptr batch; - ASSERT_OK(MakeIntBatchSized(600, &batch)); + ASSERT_OK(ipc::test::MakeIntBatchSized(600, &batch)); CheckSlice(batch); } @@ -549,13 +549,13 @@ TEST_P(TestTableWriterSlice, SliceStringsRoundTrip) { auto start = std::get<0>(p); auto with_nulls = start % 2 == 0; std::shared_ptr batch; - ASSERT_OK(MakeStringTypesRecordBatch(&batch, with_nulls)); + ASSERT_OK(ipc::test::MakeStringTypesRecordBatch(&batch, with_nulls)); CheckSlice(batch); } TEST_P(TestTableWriterSlice, SliceBooleanRoundTrip) { std::shared_ptr batch; - ASSERT_OK(MakeBooleanBatchSized(600, &batch)); + ASSERT_OK(ipc::test::MakeBooleanBatchSized(600, &batch)); CheckSlice(batch); } diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index 0420c133dbd..0bc0f2004af 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -343,9 +343,8 @@ class SchemaWriter { return VisitType(*type.dictionary()->type()); } - Status Visit(const ExtensionType& type) { - return Status::NotImplemented("extension type"); - } + // Default case + Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); } private: DictionaryMemo dictionary_memo_; @@ -1210,10 +1209,6 @@ class ArrayReader { return Status::OK(); } - Status Visit(const ExtensionType& type) { - return Status::NotImplemented("extension type"); - } - Status Visit(const DictionaryType& type) { // This stores the indices in result_ // @@ -1226,6 +1221,9 @@ class ArrayReader { return Status::OK(); } + // Default case + Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); } + Status GetChildren(const RjObject& obj, const DataType& type, std::vector>* array) { const auto& json_children = obj.FindMember("children"); diff --git a/cpp/src/arrow/ipc/json-test.cc b/cpp/src/arrow/ipc/json-test.cc index 72504d45a29..f6198e3e6a2 100644 --- a/cpp/src/arrow/ipc/json-test.cc +++ b/cpp/src/arrow/ipc/json-test.cc @@ -34,6 +34,7 @@ #include "arrow/record_batch.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/testing/util.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -42,6 +43,8 @@ namespace ipc { namespace internal { namespace json { +using namespace ::arrow::ipc::test; // NOLINT + void TestSchemaRoundTrip(const Schema& schema) { rj::StringBuffer sb; rj::Writer writer(sb); diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index 2589a10035f..dedeee3a632 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -974,11 +974,12 @@ FileBlocksToFlatbuffer(FBB& fbb, const std::vector& blocks) { Status WriteFileFooter(const Schema& schema, const std::vector& dictionaries, const std::vector& record_batches, - DictionaryMemo* dictionary_memo, io::OutputStream* out) { + io::OutputStream* out) { FBB fbb; flatbuffers::Offset fb_schema; - RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); + DictionaryMemo dictionary_memo; // unused + RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, &dictionary_memo, &fb_schema)); #ifndef NDEBUG for (size_t i = 0; i < dictionaries.size(); ++i) { diff --git a/cpp/src/arrow/ipc/metadata-internal.h b/cpp/src/arrow/ipc/metadata-internal.h index 6562382b878..c91983d91c0 100644 --- a/cpp/src/arrow/ipc/metadata-internal.h +++ b/cpp/src/arrow/ipc/metadata-internal.h @@ -151,7 +151,7 @@ Status WriteSparseTensorMessage(const SparseTensor& sparse_tensor, int64_t body_ Status WriteFileFooter(const Schema& schema, const std::vector& dictionaries, const std::vector& record_batches, - DictionaryMemo* dictionary_memo, io::OutputStream* out); + io::OutputStream* out); Status WriteDictionaryMessage(const int64_t id, const int64_t length, const int64_t body_length, diff --git a/cpp/src/arrow/ipc/read-write-test.cc b/cpp/src/arrow/ipc/read-write-test.cc index 6f4da28b3ff..0408a1712fe 100644 --- a/cpp/src/arrow/ipc/read-write-test.cc +++ b/cpp/src/arrow/ipc/read-write-test.cc @@ -51,24 +51,10 @@ namespace arrow { using internal::checked_cast; namespace ipc { +namespace test { using BatchVector = std::vector>; -class TestSchemaMetadata : public ::testing::Test { - public: - void SetUp() {} - - void CheckRoundtrip(const Schema& schema) { - std::shared_ptr buffer; - ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer)); - - std::shared_ptr result; - io::BufferReader reader(buffer); - ASSERT_OK(ReadSchema(&reader, &result)); - AssertSchemaEqual(schema, *result); - } -}; - TEST(TestMessage, Equals) { std::string metadata = "foo"; std::string body = "bar"; @@ -147,6 +133,21 @@ TEST(TestMessage, Verify) { ASSERT_FALSE(message.Verify()); } +class TestSchemaMetadata : public ::testing::Test { + public: + void SetUp() {} + + void CheckRoundtrip(const Schema& schema) { + std::shared_ptr buffer; + ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer)); + + std::shared_ptr result; + io::BufferReader reader(buffer); + ASSERT_OK(ReadSchema(&reader, &result)); + AssertSchemaEqual(schema, *result); + } +}; + const std::shared_ptr INT32 = std::make_shared(); TEST_F(TestSchemaMetadata, PrimitiveFields) { @@ -178,6 +179,25 @@ TEST_F(TestSchemaMetadata, NestedFields) { CheckRoundtrip(schema); } +TEST_F(TestSchemaMetadata, DictionaryFields) { + { + auto dict_type = + dictionary(int8(), ArrayFromJSON(int32(), "[6, 5, 4]"), true /* ordered */); + auto f0 = field("f0", dict_type); + auto f1 = field("f1", list(dict_type)); + + Schema schema({f0, f1}); + CheckRoundtrip(schema); + } + { + auto dict_type = dictionary(int8(), ArrayFromJSON(list(int32()), "[[4, 5], [6]]")); + auto f0 = field("f0", dict_type); + + Schema schema({f0}); + CheckRoundtrip(schema); + } +} + TEST_F(TestSchemaMetadata, KeyValueMetadata) { auto field_metadata = key_value_metadata({{"key", "value"}}); auto schema_metadata = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}}); @@ -388,7 +408,7 @@ TEST_F(TestWriteRecordBatch, SliceTruncatesBuffers) { // String / Binary { - auto s = MakeRandomBinaryArray(500, false, pool, &a0); + auto s = MakeRandomStringArray(500, false, pool, &a0); ASSERT_TRUE(s.ok()); } CheckArray(a0); @@ -993,5 +1013,6 @@ TEST(TestRecordBatchStreamReader, MalformedInput) { ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader, &batch_reader)); } +} // namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index a33f07c859c..85c64004aa6 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -56,6 +56,38 @@ namespace ipc { using internal::FileBlock; using internal::kArrowMagicBytes; +namespace { + +Status InvalidMessageType(Message::Type expected, Message::Type actual) { + return Status::IOError("Expected IPC message of type ", FormatMessageType(expected), + " got ", FormatMessageType(actual)); +} + +#define CHECK_MESSAGE_TYPE(expected, actual) \ + do { \ + if ((actual) != (expected)) { \ + return InvalidMessageType((expected), (actual)); \ + } \ + } while (0) + +#define CHECK_HAS_BODY(message) \ + do { \ + if ((message).body() == nullptr) { \ + return Status::IOError("Expected body in IPC message of type ", \ + FormatMessageType((message).type())); \ + } \ + } while (0) + +#define CHECK_HAS_NO_BODY(message) \ + do { \ + if ((message).body_length() != 0) { \ + return Status::IOError("Unexpected body in IPC message of type ", \ + FormatMessageType((message).type())); \ + } \ + } while (0) + +} // namespace + // ---------------------------------------------------------------------- // Record batch read path @@ -287,8 +319,9 @@ Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& sc Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, std::shared_ptr* out) { + CHECK_MESSAGE_TYPE(message.type(), Message::RECORD_BATCH); + CHECK_HAS_BODY(message); io::BufferReader reader(message.body()); - DCHECK_EQ(message.type(), Message::RECORD_BATCH); return ReadRecordBatch(*message.metadata(), schema, kMaxNestingDepth, &reader, out); } @@ -382,14 +415,11 @@ static Status ReadMessageAndValidate(MessageReader* reader, Message::Type expect } if ((*message) == nullptr) { + // End of stream? return Status::OK(); } - if ((*message)->type() != expected_type) { - return Status::IOError( - "Message not expected type: ", FormatMessageType(expected_type), - ", was: ", (*message)->type()); - } + CHECK_MESSAGE_TYPE((*message)->type(), expected_type); return Status::OK(); } @@ -414,7 +444,13 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { std::unique_ptr message; RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::DICTIONARY_BATCH, false, &message)); + if (message == nullptr) { + // End of stream + return Status::IOError( + "End of IPC stream when attempting to read dictionary batch"); + } + CHECK_HAS_BODY(*message); io::BufferReader reader(message->body()); std::shared_ptr dictionary; @@ -428,7 +464,12 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { std::unique_ptr message; RETURN_NOT_OK( ReadMessageAndValidate(message_reader_.get(), Message::SCHEMA, false, &message)); + if (message == nullptr) { + // End of stream + return Status::IOError("End of IPC stream when attempting to read schema"); + } + CHECK_HAS_NO_BODY(*message); if (message->header() == nullptr) { return Status::IOError("Header-pointer of flatbuffer-encoded Message is null."); } @@ -448,13 +489,13 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { std::unique_ptr message; RETURN_NOT_OK(ReadMessageAndValidate(message_reader_.get(), Message::RECORD_BATCH, true, &message)); - if (message == nullptr) { // End of stream *batch = nullptr; return Status::OK(); } + CHECK_HAS_BODY(*message); io::BufferReader reader(message->body()); return ReadRecordBatch(*message->metadata(), schema_, &reader, batch); } @@ -485,6 +526,15 @@ Status RecordBatchStreamReader::Open(std::unique_ptr message_read return Status::OK(); } +Status RecordBatchStreamReader::Open(std::unique_ptr message_reader, + std::unique_ptr* reader) { + // Private ctor + auto result = std::unique_ptr(new RecordBatchStreamReader()); + RETURN_NOT_OK(result->impl_->Open(std::move(message_reader))); + *reader = std::move(result); + return Status::OK(); +} + Status RecordBatchStreamReader::Open(io::InputStream* stream, std::shared_ptr* out) { return Open(MessageReader::Open(stream), out); @@ -854,7 +904,8 @@ Status ReadSparseTensor(const Message& message, std::shared_ptr* o Status ReadSparseTensor(io::InputStream* file, std::shared_ptr* out) { std::unique_ptr message; RETURN_NOT_OK(ReadContiguousPayload(file, &message)); - DCHECK_EQ(message->type(), Message::SPARSE_TENSOR); + CHECK_MESSAGE_TYPE(message->type(), Message::SPARSE_TENSOR); + CHECK_HAS_BODY(*message); io::BufferReader buffer_reader(message->body()); return ReadSparseTensor(*message->metadata(), &buffer_reader, out); } diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 641de3eaf7b..8fe310f5b77 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -56,13 +56,16 @@ class ARROW_EXPORT RecordBatchStreamReader : public RecordBatchReader { public: ~RecordBatchStreamReader() override; - /// Create batch reader from generic MessageReader + /// Create batch reader from generic MessageReader. + /// This will take ownership of the given MessageReader. /// /// \param[in] message_reader a MessageReader implementation /// \param[out] out the created RecordBatchReader object /// \return Status static Status Open(std::unique_ptr message_reader, std::shared_ptr* out); + static Status Open(std::unique_ptr message_reader, + std::unique_ptr* out); /// \brief Record batch stream reader from InputStream /// diff --git a/cpp/src/arrow/ipc/test-common.cc b/cpp/src/arrow/ipc/test-common.cc new file mode 100644 index 00000000000..44b608da087 --- /dev/null +++ b/cpp/src/arrow/ipc/test-common.cc @@ -0,0 +1,669 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/builder.h" +#include "arrow/ipc/test-common.h" +#include "arrow/memory_pool.h" +#include "arrow/pretty_print.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type.h" +#include "arrow/util/bit-util.h" + +namespace arrow { +namespace ipc { +namespace test { + +void CompareArraysDetailed(int index, const Array& result, const Array& expected) { + if (!expected.Equals(result)) { + std::stringstream pp_result; + std::stringstream pp_expected; + + ASSERT_OK(PrettyPrint(expected, 0, &pp_expected)); + ASSERT_OK(PrettyPrint(result, 0, &pp_result)); + + FAIL() << "Index: " << index << " Expected: " << pp_expected.str() + << "\nGot: " << pp_result.str(); + } +} + +void CompareBatchColumnsDetailed(const RecordBatch& result, const RecordBatch& expected) { + for (int i = 0; i < expected.num_columns(); ++i) { + auto left = result.column(i); + auto right = expected.column(i); + CompareArraysDetailed(i, *left, *right); + } +} + +Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool, + std::shared_ptr* out, uint32_t seed) { + random::RandomArrayGenerator rand(seed); + const double null_probability = include_nulls ? 0.5 : 0.0; + + *out = rand.Int32(length, 0, 1000, null_probability); + + return Status::OK(); +} + +Status MakeRandomListArray(const std::shared_ptr& child_array, int num_lists, + bool include_nulls, MemoryPool* pool, + std::shared_ptr* out) { + // Create the null list values + std::vector valid_lists(num_lists); + const double null_percent = include_nulls ? 0.1 : 0; + random_null_bytes(num_lists, null_percent, valid_lists.data()); + + // Create list offsets + const int max_list_size = 10; + + std::vector list_sizes(num_lists, 0); + std::vector offsets( + num_lists + 1, 0); // +1 so we can shift for nulls. See partial sum below. + const uint32_t seed = static_cast(child_array->length()); + + if (num_lists > 0) { + rand_uniform_int(num_lists, seed, 0, max_list_size, list_sizes.data()); + // make sure sizes are consistent with null + std::transform(list_sizes.begin(), list_sizes.end(), valid_lists.begin(), + list_sizes.begin(), + [](int32_t size, int32_t valid) { return valid == 0 ? 0 : size; }); + std::partial_sum(list_sizes.begin(), list_sizes.end(), ++offsets.begin()); + + // Force invariants + const int32_t child_length = static_cast(child_array->length()); + offsets[0] = 0; + std::replace_if(offsets.begin(), offsets.end(), + [child_length](int32_t offset) { return offset > child_length; }, + child_length); + } + + offsets[num_lists] = static_cast(child_array->length()); + + /// TODO(wesm): Implement support for nulls in ListArray::FromArrays + std::shared_ptr null_bitmap, offsets_buffer; + RETURN_NOT_OK(GetBitmapFromVector(valid_lists, &null_bitmap)); + RETURN_NOT_OK(CopyBufferFromVector(offsets, pool, &offsets_buffer)); + + *out = std::make_shared(list(child_array->type()), num_lists, offsets_buffer, + child_array, null_bitmap, kUnknownNullCount); + return ValidateArray(**out); +} + +Status MakeRandomBooleanArray(const int length, bool include_nulls, + std::shared_ptr* out) { + std::vector values(length); + random_null_bytes(length, 0.5, values.data()); + std::shared_ptr data; + RETURN_NOT_OK(BitUtil::BytesToBits(values, default_memory_pool(), &data)); + + if (include_nulls) { + std::vector valid_bytes(length); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(BitUtil::BytesToBits(valid_bytes, default_memory_pool(), &null_bitmap)); + random_null_bytes(length, 0.1, valid_bytes.data()); + *out = std::make_shared(length, data, null_bitmap, -1); + } else { + *out = std::make_shared(length, data, NULLPTR, 0); + } + return Status::OK(); +} + +Status MakeBooleanBatchSized(const int length, std::shared_ptr* out) { + // Make the schema + auto f0 = field("f0", boolean()); + auto f1 = field("f1", boolean()); + auto schema = ::arrow::schema({f0, f1}); + + std::shared_ptr a0, a1; + RETURN_NOT_OK(MakeRandomBooleanArray(length, true, &a0)); + RETURN_NOT_OK(MakeRandomBooleanArray(length, false, &a1)); + *out = RecordBatch::Make(schema, length, {a0, a1}); + return Status::OK(); +} + +Status MakeBooleanBatch(std::shared_ptr* out) { + return MakeBooleanBatchSized(1000, out); +} + +Status MakeIntBatchSized(int length, std::shared_ptr* out, uint32_t seed) { + // Make the schema + auto f0 = field("f0", int32()); + auto f1 = field("f1", int32()); + auto schema = ::arrow::schema({f0, f1}); + + // Example data + std::shared_ptr a0, a1; + MemoryPool* pool = default_memory_pool(); + RETURN_NOT_OK(MakeRandomInt32Array(length, false, pool, &a0, seed)); + RETURN_NOT_OK(MakeRandomInt32Array(length, true, pool, &a1, seed + 1)); + *out = RecordBatch::Make(schema, length, {a0, a1}); + return Status::OK(); +} + +Status MakeIntRecordBatch(std::shared_ptr* out) { + return MakeIntBatchSized(10, out); +} + +Status MakeRandomStringArray(int64_t length, bool include_nulls, MemoryPool* pool, + std::shared_ptr* out) { + const std::vector values = {"", "", "abc", "123", + "efg", "456!@#!@#", "12312"}; + StringBuilder builder(pool); + const size_t values_len = values.size(); + for (int64_t i = 0; i < length; ++i) { + int64_t values_index = i % values_len; + if (include_nulls && values_index == 0) { + RETURN_NOT_OK(builder.AppendNull()); + } else { + const auto& value = values[values_index]; + RETURN_NOT_OK(builder.Append(value)); + } + } + return builder.Finish(out); +} + +template +static Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls, + MemoryPool* pool, + std::shared_ptr* out) { + Builder builder(pool); + for (int64_t i = 0; i < length; ++i) { + if (include_nulls && (i % 7 == 0)) { + RETURN_NOT_OK(builder.AppendNull()); + } else { + const std::string value = std::to_string(i); + RETURN_NOT_OK(builder.Append(reinterpret_cast(value.data()), + static_cast(value.size()))); + } + } + return builder.Finish(out); +} + +Status MakeStringTypesRecordBatch(std::shared_ptr* out, bool with_nulls) { + const int64_t length = 500; + auto string_type = utf8(); + auto binary_type = binary(); + auto f0 = field("f0", string_type); + auto f1 = field("f1", binary_type); + auto schema = ::arrow::schema({f0, f1}); + + std::shared_ptr a0, a1; + MemoryPool* pool = default_memory_pool(); + + // Quirk with RETURN_NOT_OK macro and templated functions + { + auto s = MakeBinaryArrayWithUniqueValues(length, with_nulls, + pool, &a0); + RETURN_NOT_OK(s); + } + + { + auto s = MakeBinaryArrayWithUniqueValues(length, with_nulls, + pool, &a1); + RETURN_NOT_OK(s); + } + *out = RecordBatch::Make(schema, length, {a0, a1}); + return Status::OK(); +} + +Status MakeStringTypesRecordBatchWithNulls(std::shared_ptr* out) { + return MakeStringTypesRecordBatch(out, true); +} + +Status MakeNullRecordBatch(std::shared_ptr* out) { + const int64_t length = 500; + auto f0 = field("f0", null()); + auto schema = ::arrow::schema({f0}); + std::shared_ptr a0 = std::make_shared(length); + *out = RecordBatch::Make(schema, length, {a0}); + return Status::OK(); +} + +Status MakeListRecordBatch(std::shared_ptr* out) { + // Make the schema + auto f0 = field("f0", list(int32())); + auto f1 = field("f1", list(list(int32()))); + auto f2 = field("f2", int32()); + auto schema = ::arrow::schema({f0, f1, f2}); + + // Example data + + MemoryPool* pool = default_memory_pool(); + const int length = 200; + std::shared_ptr leaf_values, list_array, list_list_array, flat_array; + const bool include_nulls = true; + RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &leaf_values)); + RETURN_NOT_OK( + MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array)); + RETURN_NOT_OK( + MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array)); + RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array)); + *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array}); + return Status::OK(); +} + +Status MakeZeroLengthRecordBatch(std::shared_ptr* out) { + // Make the schema + auto f0 = field("f0", list(int32())); + auto f1 = field("f1", list(list(int32()))); + auto f2 = field("f2", int32()); + auto schema = ::arrow::schema({f0, f1, f2}); + + // Example data + MemoryPool* pool = default_memory_pool(); + const bool include_nulls = true; + std::shared_ptr leaf_values, list_array, list_list_array, flat_array; + RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &leaf_values)); + RETURN_NOT_OK(MakeRandomListArray(leaf_values, 0, include_nulls, pool, &list_array)); + RETURN_NOT_OK( + MakeRandomListArray(list_array, 0, include_nulls, pool, &list_list_array)); + RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &flat_array)); + *out = RecordBatch::Make(schema, 0, {list_array, list_list_array, flat_array}); + return Status::OK(); +} + +Status MakeNonNullRecordBatch(std::shared_ptr* out) { + // Make the schema + auto f0 = field("f0", list(int32())); + auto f1 = field("f1", list(list(int32()))); + auto f2 = field("f2", int32()); + auto schema = ::arrow::schema({f0, f1, f2}); + + // Example data + MemoryPool* pool = default_memory_pool(); + const int length = 50; + std::shared_ptr leaf_values, list_array, list_list_array, flat_array; + + RETURN_NOT_OK(MakeRandomInt32Array(1000, true, pool, &leaf_values)); + bool include_nulls = false; + RETURN_NOT_OK( + MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array)); + RETURN_NOT_OK( + MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array)); + RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array)); + *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array}); + return Status::OK(); +} + +Status MakeDeeplyNestedList(std::shared_ptr* out) { + const int batch_length = 5; + auto type = int32(); + + MemoryPool* pool = default_memory_pool(); + std::shared_ptr array; + const bool include_nulls = true; + RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &array)); + for (int i = 0; i < 63; ++i) { + type = std::static_pointer_cast(list(type)); + RETURN_NOT_OK(MakeRandomListArray(array, batch_length, include_nulls, pool, &array)); + } + + auto f0 = field("f0", type); + auto schema = ::arrow::schema({f0}); + std::vector> arrays = {array}; + *out = RecordBatch::Make(schema, batch_length, arrays); + return Status::OK(); +} + +Status MakeStruct(std::shared_ptr* out) { + // reuse constructed list columns + std::shared_ptr list_batch; + RETURN_NOT_OK(MakeListRecordBatch(&list_batch)); + std::vector> columns = { + list_batch->column(0), list_batch->column(1), list_batch->column(2)}; + auto list_schema = list_batch->schema(); + + // Define schema + std::shared_ptr type(new StructType( + {list_schema->field(0), list_schema->field(1), list_schema->field(2)})); + auto f0 = field("non_null_struct", type); + auto f1 = field("null_struct", type); + auto schema = ::arrow::schema({f0, f1}); + + // construct individual nullable/non-nullable struct arrays + std::shared_ptr no_nulls(new StructArray(type, list_batch->num_rows(), columns)); + std::vector null_bytes(list_batch->num_rows(), 1); + null_bytes[0] = 0; + std::shared_ptr null_bitmask; + RETURN_NOT_OK(BitUtil::BytesToBits(null_bytes, default_memory_pool(), &null_bitmask)); + std::shared_ptr with_nulls( + new StructArray(type, list_batch->num_rows(), columns, null_bitmask, 1)); + + // construct batch + std::vector> arrays = {no_nulls, with_nulls}; + *out = RecordBatch::Make(schema, list_batch->num_rows(), arrays); + return Status::OK(); +} + +Status MakeUnion(std::shared_ptr* out) { + // Define schema + std::vector> union_types( + {field("u0", int32()), field("u1", uint8())}); + + std::vector type_codes = {5, 10}; + auto sparse_type = + std::make_shared(union_types, type_codes, UnionMode::SPARSE); + + auto dense_type = + std::make_shared(union_types, type_codes, UnionMode::DENSE); + + auto f0 = field("sparse_nonnull", sparse_type, false); + auto f1 = field("sparse", sparse_type); + auto f2 = field("dense", dense_type); + + auto schema = ::arrow::schema({f0, f1, f2}); + + // Create data + std::vector> sparse_children(2); + std::vector> dense_children(2); + + const int64_t length = 7; + + std::shared_ptr type_ids_buffer; + std::vector type_ids = {5, 10, 5, 5, 10, 10, 5}; + RETURN_NOT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &type_ids_buffer)); + + std::vector u0_values = {0, 1, 2, 3, 4, 5, 6}; + ArrayFromVector(u0_values, &sparse_children[0]); + + std::vector u1_values = {10, 11, 12, 13, 14, 15, 16}; + ArrayFromVector(u1_values, &sparse_children[1]); + + // dense children + u0_values = {0, 2, 3, 7}; + ArrayFromVector(u0_values, &dense_children[0]); + + u1_values = {11, 14, 15}; + ArrayFromVector(u1_values, &dense_children[1]); + + std::shared_ptr offsets_buffer; + std::vector offsets = {0, 0, 1, 2, 1, 2, 3}; + RETURN_NOT_OK(CopyBufferFromVector(offsets, default_memory_pool(), &offsets_buffer)); + + std::vector null_bytes(length, 1); + null_bytes[2] = 0; + std::shared_ptr null_bitmask; + RETURN_NOT_OK(BitUtil::BytesToBits(null_bytes, default_memory_pool(), &null_bitmask)); + + // construct individual nullable/non-nullable struct arrays + auto sparse_no_nulls = + std::make_shared(sparse_type, length, sparse_children, type_ids_buffer); + auto sparse = std::make_shared(sparse_type, length, sparse_children, + type_ids_buffer, NULLPTR, null_bitmask, 1); + + auto dense = + std::make_shared(dense_type, length, dense_children, type_ids_buffer, + offsets_buffer, null_bitmask, 1); + + // construct batch + std::vector> arrays = {sparse_no_nulls, sparse, dense}; + *out = RecordBatch::Make(schema, length, arrays); + return Status::OK(); +} + +Status MakeDictionary(std::shared_ptr* out) { + const int64_t length = 6; + + std::vector is_valid = {true, true, false, true, true, true}; + + auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); + auto dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]"); + + auto f0_type = arrow::dictionary(arrow::int32(), dict1); + auto f1_type = arrow::dictionary(arrow::int8(), dict1, true); + auto f2_type = arrow::dictionary(arrow::int32(), dict2); + + std::shared_ptr indices0, indices1, indices2; + std::vector indices0_values = {1, 2, -1, 0, 2, 0}; + std::vector indices1_values = {0, 0, 2, 2, 1, 1}; + std::vector indices2_values = {3, 0, 2, 1, 0, 2}; + + ArrayFromVector(is_valid, indices0_values, &indices0); + ArrayFromVector(is_valid, indices1_values, &indices1); + ArrayFromVector(is_valid, indices2_values, &indices2); + + auto a0 = std::make_shared(f0_type, indices0); + auto a1 = std::make_shared(f1_type, indices1); + auto a2 = std::make_shared(f2_type, indices2); + + // Lists of dictionary-encoded strings + auto f3_type = list(f1_type); + + auto indices3 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 1, 1, 2, 1, 0]"); + auto offsets3 = ArrayFromJSON(int32(), "[0, 0, 2, 2, 5, 6, 9]"); + + std::shared_ptr null_bitmap; + RETURN_NOT_OK(GetBitmapFromVector(is_valid, &null_bitmap)); + + std::shared_ptr a3 = std::make_shared( + f3_type, length, std::static_pointer_cast(offsets3)->values(), + std::make_shared(f1_type, indices3), null_bitmap, 1); + + // Dictionary-encoded lists of integers + auto dict4 = ArrayFromJSON(list(int8()), "[[44, 55], [], [66]]"); + auto f4_type = dictionary(int8(), dict4); + + auto indices4 = ArrayFromJSON(int8(), "[0, 1, 2, 0, 2, 2]"); + auto a4 = std::make_shared(f4_type, indices4); + + // construct batch + auto schema = ::arrow::schema( + {field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type), + field("list", f3_type), field("encoded list", f4_type)}); + + std::vector> arrays = {a0, a1, a2, a3, a4}; + + *out = RecordBatch::Make(schema, length, arrays); + return Status::OK(); +} + +Status MakeDictionaryFlat(std::shared_ptr* out) { + const int64_t length = 6; + + std::vector is_valid = {true, true, false, true, true, true}; + + auto dict1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\"]"); + auto dict2 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"baz\", \"qux\"]"); + + auto f0_type = arrow::dictionary(arrow::int32(), dict1); + auto f1_type = arrow::dictionary(arrow::int8(), dict1); + auto f2_type = arrow::dictionary(arrow::int32(), dict2); + + std::shared_ptr indices0, indices1, indices2; + std::vector indices0_values = {1, 2, -1, 0, 2, 0}; + std::vector indices1_values = {0, 0, 2, 2, 1, 1}; + std::vector indices2_values = {3, 0, 2, 1, 0, 2}; + + ArrayFromVector(is_valid, indices0_values, &indices0); + ArrayFromVector(is_valid, indices1_values, &indices1); + ArrayFromVector(is_valid, indices2_values, &indices2); + + auto a0 = std::make_shared(f0_type, indices0); + auto a1 = std::make_shared(f1_type, indices1); + auto a2 = std::make_shared(f2_type, indices2); + + // construct batch + auto schema = ::arrow::schema( + {field("dict1", f0_type), field("dict2", f1_type), field("dict3", f2_type)}); + + std::vector> arrays = {a0, a1, a2}; + *out = RecordBatch::Make(schema, length, arrays); + return Status::OK(); +} + +Status MakeDates(std::shared_ptr* out) { + std::vector is_valid = {true, true, true, false, true, true, true}; + auto f0 = field("f0", date32()); + auto f1 = field("f1", date64()); + auto schema = ::arrow::schema({f0, f1}); + + std::vector date32_values = {0, 1, 2, 3, 4, 5, 6}; + std::shared_ptr date32_array; + ArrayFromVector(is_valid, date32_values, &date32_array); + + std::vector date64_values = {1489269000000, 1489270000000, 1489271000000, + 1489272000000, 1489272000000, 1489273000000, + 1489274000000}; + std::shared_ptr date64_array; + ArrayFromVector(is_valid, date64_values, &date64_array); + + *out = RecordBatch::Make(schema, date32_array->length(), {date32_array, date64_array}); + return Status::OK(); +} + +Status MakeTimestamps(std::shared_ptr* out) { + std::vector is_valid = {true, true, true, false, true, true, true}; + auto f0 = field("f0", timestamp(TimeUnit::MILLI)); + auto f1 = field("f1", timestamp(TimeUnit::NANO, "America/New_York")); + auto f2 = field("f2", timestamp(TimeUnit::SECOND)); + auto schema = ::arrow::schema({f0, f1, f2}); + + std::vector ts_values = {1489269000000, 1489270000000, 1489271000000, + 1489272000000, 1489272000000, 1489273000000}; + + std::shared_ptr a0, a1, a2; + ArrayFromVector(f0->type(), is_valid, ts_values, &a0); + ArrayFromVector(f1->type(), is_valid, ts_values, &a1); + ArrayFromVector(f2->type(), is_valid, ts_values, &a2); + + *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2}); + return Status::OK(); +} + +Status MakeTimes(std::shared_ptr* out) { + std::vector is_valid = {true, true, true, false, true, true, true}; + auto f0 = field("f0", time32(TimeUnit::MILLI)); + auto f1 = field("f1", time64(TimeUnit::NANO)); + auto f2 = field("f2", time32(TimeUnit::SECOND)); + auto f3 = field("f3", time64(TimeUnit::NANO)); + auto schema = ::arrow::schema({f0, f1, f2, f3}); + + std::vector t32_values = {1489269000, 1489270000, 1489271000, + 1489272000, 1489272000, 1489273000}; + std::vector t64_values = {1489269000000, 1489270000000, 1489271000000, + 1489272000000, 1489272000000, 1489273000000}; + + std::shared_ptr a0, a1, a2, a3; + ArrayFromVector(f0->type(), is_valid, t32_values, &a0); + ArrayFromVector(f1->type(), is_valid, t64_values, &a1); + ArrayFromVector(f2->type(), is_valid, t32_values, &a2); + ArrayFromVector(f3->type(), is_valid, t64_values, &a3); + + *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3}); + return Status::OK(); +} + +template +static void AppendValues(const std::vector& is_valid, const std::vector& values, + BuilderType* builder) { + for (size_t i = 0; i < values.size(); ++i) { + if (is_valid[i]) { + ASSERT_OK(builder->Append(values[i])); + } else { + ASSERT_OK(builder->AppendNull()); + } + } +} + +Status MakeFWBinary(std::shared_ptr* out) { + std::vector is_valid = {true, true, true, false}; + auto f0 = field("f0", fixed_size_binary(4)); + auto f1 = field("f1", fixed_size_binary(0)); + auto schema = ::arrow::schema({f0, f1}); + + std::shared_ptr a1, a2; + + FixedSizeBinaryBuilder b1(f0->type()); + FixedSizeBinaryBuilder b2(f1->type()); + + std::vector values1 = {"foo1", "foo2", "foo3", "foo4"}; + AppendValues(is_valid, values1, &b1); + + std::vector values2 = {"", "", "", ""}; + AppendValues(is_valid, values2, &b2); + + RETURN_NOT_OK(b1.Finish(&a1)); + RETURN_NOT_OK(b2.Finish(&a2)); + + *out = RecordBatch::Make(schema, a1->length(), {a1, a2}); + return Status::OK(); +} + +Status MakeDecimal(std::shared_ptr* out) { + constexpr int kDecimalPrecision = 38; + auto type = decimal(kDecimalPrecision, 4); + auto f0 = field("f0", type); + auto f1 = field("f1", type); + auto schema = ::arrow::schema({f0, f1}); + + constexpr int kDecimalSize = 16; + constexpr int length = 10; + + std::shared_ptr data, is_valid; + std::vector is_valid_bytes(length); + + RETURN_NOT_OK(AllocateBuffer(kDecimalSize * length, &data)); + + random_decimals(length, 1, kDecimalPrecision, data->mutable_data()); + random_null_bytes(length, 0.1, is_valid_bytes.data()); + + RETURN_NOT_OK(BitUtil::BytesToBits(is_valid_bytes, default_memory_pool(), &is_valid)); + + auto a1 = std::make_shared(f0->type(), length, data, is_valid, + kUnknownNullCount); + + auto a2 = std::make_shared(f1->type(), length, data); + + *out = RecordBatch::Make(schema, length, {a1, a2}); + return Status::OK(); +} + +Status MakeNull(std::shared_ptr* out) { + auto f0 = field("f0", null()); + + // Also put a non-null field to make sure we handle the null array buffers properly + auto f1 = field("f1", int64()); + + auto schema = ::arrow::schema({f0, f1}); + + auto a1 = std::make_shared(10); + + std::vector int_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector is_valid = {true, true, true, false, false, + true, true, true, true, true}; + std::shared_ptr a2; + ArrayFromVector(f1->type(), is_valid, int_values, &a2); + + *out = RecordBatch::Make(schema, a1->length(), {a1, a2}); + return Status::OK(); +} + +} // namespace test +} // namespace ipc +} // namespace arrow diff --git a/cpp/src/arrow/ipc/test-common.h b/cpp/src/arrow/ipc/test-common.h index 8593fbc0860..735991b3231 100644 --- a/cpp/src/arrow/ipc/test-common.h +++ b/cpp/src/arrow/ipc/test-common.h @@ -18,695 +18,110 @@ #ifndef ARROW_IPC_TEST_COMMON_H #define ARROW_IPC_TEST_COMMON_H -#include #include #include -#include -#include -#include #include "arrow/array.h" -#include "arrow/buffer.h" -#include "arrow/builder.h" -#include "arrow/memory_pool.h" -#include "arrow/pretty_print.h" #include "arrow/record_batch.h" #include "arrow/status.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/testing/random.h" -#include "arrow/testing/util.h" #include "arrow/type.h" -#include "arrow/util/bit-util.h" namespace arrow { namespace ipc { +namespace test { -static inline void CompareArraysDetailed(int index, const Array& result, - const Array& expected) { - if (!expected.Equals(result)) { - std::stringstream pp_result; - std::stringstream pp_expected; - - ASSERT_OK(PrettyPrint(expected, 0, &pp_expected)); - ASSERT_OK(PrettyPrint(result, 0, &pp_result)); - - FAIL() << "Index: " << index << " Expected: " << pp_expected.str() - << "\nGot: " << pp_result.str(); - } -} - -static inline void CompareBatchColumnsDetailed(const RecordBatch& result, - const RecordBatch& expected) { - for (int i = 0; i < expected.num_columns(); ++i) { - auto left = result.column(i); - auto right = expected.column(i); - CompareArraysDetailed(i, *left, *right); - } -} - -const auto kListInt32 = list(int32()); -const auto kListListInt32 = list(kListInt32); - -static inline Status MakeRandomInt32Array(int64_t length, bool include_nulls, - MemoryPool* pool, std::shared_ptr* out, - uint32_t seed = 0) { - random::RandomArrayGenerator rand(seed); - const double null_probability = include_nulls ? 0.5 : 0.0; - - *out = rand.Int32(length, 0, 1000, null_probability); - - return Status::OK(); -} - -static inline Status MakeRandomListArray(const std::shared_ptr& child_array, - int num_lists, bool include_nulls, - MemoryPool* pool, std::shared_ptr* out) { - // Create the null list values - std::vector valid_lists(num_lists); - const double null_percent = include_nulls ? 0.1 : 0; - random_null_bytes(num_lists, null_percent, valid_lists.data()); - - // Create list offsets - const int max_list_size = 10; - - std::vector list_sizes(num_lists, 0); - std::vector offsets( - num_lists + 1, 0); // +1 so we can shift for nulls. See partial sum below. - const uint32_t seed = static_cast(child_array->length()); - - if (num_lists > 0) { - rand_uniform_int(num_lists, seed, 0, max_list_size, list_sizes.data()); - // make sure sizes are consistent with null - std::transform(list_sizes.begin(), list_sizes.end(), valid_lists.begin(), - list_sizes.begin(), - [](int32_t size, int32_t valid) { return valid == 0 ? 0 : size; }); - std::partial_sum(list_sizes.begin(), list_sizes.end(), ++offsets.begin()); - - // Force invariants - const int32_t child_length = static_cast(child_array->length()); - offsets[0] = 0; - std::replace_if(offsets.begin(), offsets.end(), - [child_length](int32_t offset) { return offset > child_length; }, - child_length); - } - - offsets[num_lists] = static_cast(child_array->length()); - - /// TODO(wesm): Implement support for nulls in ListArray::FromArrays - std::shared_ptr null_bitmap, offsets_buffer; - RETURN_NOT_OK(GetBitmapFromVector(valid_lists, &null_bitmap)); - RETURN_NOT_OK(CopyBufferFromVector(offsets, pool, &offsets_buffer)); - - *out = std::make_shared(list(child_array->type()), num_lists, offsets_buffer, - child_array, null_bitmap, kUnknownNullCount); - return ValidateArray(**out); -} - +// A typedef used for test parameterization typedef Status MakeRecordBatch(std::shared_ptr* out); -static inline Status MakeRandomBooleanArray(const int length, bool include_nulls, - std::shared_ptr* out) { - std::vector values(length); - random_null_bytes(length, 0.5, values.data()); - std::shared_ptr data; - RETURN_NOT_OK(BitUtil::BytesToBits(values, default_memory_pool(), &data)); - - if (include_nulls) { - std::vector valid_bytes(length); - std::shared_ptr null_bitmap; - RETURN_NOT_OK(BitUtil::BytesToBits(valid_bytes, default_memory_pool(), &null_bitmap)); - random_null_bytes(length, 0.1, valid_bytes.data()); - *out = std::make_shared(length, data, null_bitmap, -1); - } else { - *out = std::make_shared(length, data, NULLPTR, 0); - } - return Status::OK(); -} - -static inline Status MakeBooleanBatchSized(const int length, - std::shared_ptr* out) { - // Make the schema - auto f0 = field("f0", boolean()); - auto f1 = field("f1", boolean()); - auto schema = ::arrow::schema({f0, f1}); - - std::shared_ptr a0, a1; - RETURN_NOT_OK(MakeRandomBooleanArray(length, true, &a0)); - RETURN_NOT_OK(MakeRandomBooleanArray(length, false, &a1)); - *out = RecordBatch::Make(schema, length, {a0, a1}); - return Status::OK(); -} - -static inline Status MakeBooleanBatch(std::shared_ptr* out) { - return MakeBooleanBatchSized(1000, out); -} - -static inline Status MakeIntBatchSized(int length, std::shared_ptr* out, - uint32_t seed = 0) { - // Make the schema - auto f0 = field("f0", int32()); - auto f1 = field("f1", int32()); - auto schema = ::arrow::schema({f0, f1}); - - // Example data - std::shared_ptr a0, a1; - MemoryPool* pool = default_memory_pool(); - RETURN_NOT_OK(MakeRandomInt32Array(length, false, pool, &a0, seed)); - RETURN_NOT_OK(MakeRandomInt32Array(length, true, pool, &a1, seed + 1)); - *out = RecordBatch::Make(schema, length, {a0, a1}); - return Status::OK(); -} - -static inline Status MakeIntRecordBatch(std::shared_ptr* out) { - return MakeIntBatchSized(10, out); -} - -template -Status MakeRandomBinaryArray(int64_t length, bool include_nulls, MemoryPool* pool, - std::shared_ptr* out) { - const std::vector values = {"", "", "abc", "123", - "efg", "456!@#!@#", "12312"}; - Builder builder(pool); - const size_t values_len = values.size(); - for (int64_t i = 0; i < length; ++i) { - int64_t values_index = i % values_len; - if (include_nulls && values_index == 0) { - RETURN_NOT_OK(builder.AppendNull()); - } else { - const std::string& value = values[values_index]; - RETURN_NOT_OK(builder.Append(reinterpret_cast(value.data()), - static_cast(value.size()))); - } - } - return builder.Finish(out); -} - -template -Status MakeBinaryArrayWithUniqueValues(int64_t length, bool include_nulls, - MemoryPool* pool, std::shared_ptr* out) { - Builder builder(pool); - for (int64_t i = 0; i < length; ++i) { - if (include_nulls && (i % 7 == 0)) { - RETURN_NOT_OK(builder.AppendNull()); - } else { - const std::string value = std::to_string(i); - RETURN_NOT_OK(builder.Append(reinterpret_cast(value.data()), - static_cast(value.size()))); - } - } - return builder.Finish(out); -} - -static inline Status MakeStringTypesRecordBatch(std::shared_ptr* out, - bool with_nulls = true) { - const int64_t length = 500; - auto string_type = utf8(); - auto binary_type = binary(); - auto f0 = field("f0", string_type); - auto f1 = field("f1", binary_type); - auto schema = ::arrow::schema({f0, f1}); - - std::shared_ptr a0, a1; - MemoryPool* pool = default_memory_pool(); - - // Quirk with RETURN_NOT_OK macro and templated functions - { - auto s = MakeBinaryArrayWithUniqueValues(length, with_nulls, - pool, &a0); - RETURN_NOT_OK(s); - } - - { - auto s = MakeBinaryArrayWithUniqueValues(length, with_nulls, - pool, &a1); - RETURN_NOT_OK(s); - } - *out = RecordBatch::Make(schema, length, {a0, a1}); - return Status::OK(); -} - -static inline Status MakeStringTypesRecordBatchWithNulls( - std::shared_ptr* out) { - return MakeStringTypesRecordBatch(out, true); -} - -static inline Status MakeNullRecordBatch(std::shared_ptr* out) { - const int64_t length = 500; - auto f0 = field("f0", null()); - auto schema = ::arrow::schema({f0}); - std::shared_ptr a0 = std::make_shared(length); - *out = RecordBatch::Make(schema, length, {a0}); - return Status::OK(); -} - -static inline Status MakeListRecordBatch(std::shared_ptr* out) { - // Make the schema - auto f0 = field("f0", kListInt32); - auto f1 = field("f1", kListListInt32); - auto f2 = field("f2", int32()); - auto schema = ::arrow::schema({f0, f1, f2}); - - // Example data - - MemoryPool* pool = default_memory_pool(); - const int length = 200; - std::shared_ptr leaf_values, list_array, list_list_array, flat_array; - const bool include_nulls = true; - RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &leaf_values)); - RETURN_NOT_OK( - MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array)); - RETURN_NOT_OK( - MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array)); - RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array)); - *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array}); - return Status::OK(); -} - -static inline Status MakeZeroLengthRecordBatch(std::shared_ptr* out) { - // Make the schema - auto f0 = field("f0", kListInt32); - auto f1 = field("f1", kListListInt32); - auto f2 = field("f2", int32()); - auto schema = ::arrow::schema({f0, f1, f2}); - - // Example data - MemoryPool* pool = default_memory_pool(); - const bool include_nulls = true; - std::shared_ptr leaf_values, list_array, list_list_array, flat_array; - RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &leaf_values)); - RETURN_NOT_OK(MakeRandomListArray(leaf_values, 0, include_nulls, pool, &list_array)); - RETURN_NOT_OK( - MakeRandomListArray(list_array, 0, include_nulls, pool, &list_list_array)); - RETURN_NOT_OK(MakeRandomInt32Array(0, include_nulls, pool, &flat_array)); - *out = RecordBatch::Make(schema, 0, {list_array, list_list_array, flat_array}); - return Status::OK(); -} - -static inline Status MakeNonNullRecordBatch(std::shared_ptr* out) { - // Make the schema - auto f0 = field("f0", kListInt32); - auto f1 = field("f1", kListListInt32); - auto f2 = field("f2", int32()); - auto schema = ::arrow::schema({f0, f1, f2}); - - // Example data - MemoryPool* pool = default_memory_pool(); - const int length = 50; - std::shared_ptr leaf_values, list_array, list_list_array, flat_array; - - RETURN_NOT_OK(MakeRandomInt32Array(1000, true, pool, &leaf_values)); - bool include_nulls = false; - RETURN_NOT_OK( - MakeRandomListArray(leaf_values, length, include_nulls, pool, &list_array)); - RETURN_NOT_OK( - MakeRandomListArray(list_array, length, include_nulls, pool, &list_list_array)); - RETURN_NOT_OK(MakeRandomInt32Array(length, include_nulls, pool, &flat_array)); - *out = RecordBatch::Make(schema, length, {list_array, list_list_array, flat_array}); - return Status::OK(); -} - -static inline Status MakeDeeplyNestedList(std::shared_ptr* out) { - const int batch_length = 5; - auto type = int32(); - - MemoryPool* pool = default_memory_pool(); - std::shared_ptr array; - const bool include_nulls = true; - RETURN_NOT_OK(MakeRandomInt32Array(1000, include_nulls, pool, &array)); - for (int i = 0; i < 63; ++i) { - type = std::static_pointer_cast(list(type)); - RETURN_NOT_OK(MakeRandomListArray(array, batch_length, include_nulls, pool, &array)); - } - - auto f0 = field("f0", type); - auto schema = ::arrow::schema({f0}); - std::vector> arrays = {array}; - *out = RecordBatch::Make(schema, batch_length, arrays); - return Status::OK(); -} - -static inline Status MakeStruct(std::shared_ptr* out) { - // reuse constructed list columns - std::shared_ptr list_batch; - RETURN_NOT_OK(MakeListRecordBatch(&list_batch)); - std::vector> columns = { - list_batch->column(0), list_batch->column(1), list_batch->column(2)}; - auto list_schema = list_batch->schema(); - - // Define schema - std::shared_ptr type(new StructType( - {list_schema->field(0), list_schema->field(1), list_schema->field(2)})); - auto f0 = field("non_null_struct", type); - auto f1 = field("null_struct", type); - auto schema = ::arrow::schema({f0, f1}); - - // construct individual nullable/non-nullable struct arrays - std::shared_ptr no_nulls(new StructArray(type, list_batch->num_rows(), columns)); - std::vector null_bytes(list_batch->num_rows(), 1); - null_bytes[0] = 0; - std::shared_ptr null_bitmask; - RETURN_NOT_OK(BitUtil::BytesToBits(null_bytes, default_memory_pool(), &null_bitmask)); - std::shared_ptr with_nulls( - new StructArray(type, list_batch->num_rows(), columns, null_bitmask, 1)); - - // construct batch - std::vector> arrays = {no_nulls, with_nulls}; - *out = RecordBatch::Make(schema, list_batch->num_rows(), arrays); - return Status::OK(); -} - -static inline Status MakeUnion(std::shared_ptr* out) { - // Define schema - std::vector> union_types( - {field("u0", int32()), field("u1", uint8())}); - - std::vector type_codes = {5, 10}; - auto sparse_type = - std::make_shared(union_types, type_codes, UnionMode::SPARSE); - - auto dense_type = - std::make_shared(union_types, type_codes, UnionMode::DENSE); - - auto f0 = field("sparse_nonnull", sparse_type, false); - auto f1 = field("sparse", sparse_type); - auto f2 = field("dense", dense_type); - - auto schema = ::arrow::schema({f0, f1, f2}); - - // Create data - std::vector> sparse_children(2); - std::vector> dense_children(2); - - const int64_t length = 7; - - std::shared_ptr type_ids_buffer; - std::vector type_ids = {5, 10, 5, 5, 10, 10, 5}; - RETURN_NOT_OK(CopyBufferFromVector(type_ids, default_memory_pool(), &type_ids_buffer)); - - std::vector u0_values = {0, 1, 2, 3, 4, 5, 6}; - ArrayFromVector(u0_values, &sparse_children[0]); - - std::vector u1_values = {10, 11, 12, 13, 14, 15, 16}; - ArrayFromVector(u1_values, &sparse_children[1]); - - // dense children - u0_values = {0, 2, 3, 7}; - ArrayFromVector(u0_values, &dense_children[0]); - - u1_values = {11, 14, 15}; - ArrayFromVector(u1_values, &dense_children[1]); - - std::shared_ptr offsets_buffer; - std::vector offsets = {0, 0, 1, 2, 1, 2, 3}; - RETURN_NOT_OK(CopyBufferFromVector(offsets, default_memory_pool(), &offsets_buffer)); - - std::vector null_bytes(length, 1); - null_bytes[2] = 0; - std::shared_ptr null_bitmask; - RETURN_NOT_OK(BitUtil::BytesToBits(null_bytes, default_memory_pool(), &null_bitmask)); - - // construct individual nullable/non-nullable struct arrays - auto sparse_no_nulls = - std::make_shared(sparse_type, length, sparse_children, type_ids_buffer); - auto sparse = std::make_shared(sparse_type, length, sparse_children, - type_ids_buffer, NULLPTR, null_bitmask, 1); - - auto dense = - std::make_shared(dense_type, length, dense_children, type_ids_buffer, - offsets_buffer, null_bitmask, 1); - - // construct batch - std::vector> arrays = {sparse_no_nulls, sparse, dense}; - *out = RecordBatch::Make(schema, length, arrays); - return Status::OK(); -} - -static inline Status MakeDictionary(std::shared_ptr* out) { - const int64_t length = 6; - - std::vector is_valid = {true, true, false, true, true, true}; - std::shared_ptr dict1, dict2; - - std::vector dict1_values = {"foo", "bar", "baz"}; - std::vector dict2_values = {"foo", "bar", "baz", "qux"}; - - ArrayFromVector(dict1_values, &dict1); - ArrayFromVector(dict2_values, &dict2); - - auto f0_type = arrow::dictionary(arrow::int32(), dict1); - auto f1_type = arrow::dictionary(arrow::int8(), dict1, true); - auto f2_type = arrow::dictionary(arrow::int32(), dict2); - - std::shared_ptr indices0, indices1, indices2; - std::vector indices0_values = {1, 2, -1, 0, 2, 0}; - std::vector indices1_values = {0, 0, 2, 2, 1, 1}; - std::vector indices2_values = {3, 0, 2, 1, 0, 2}; - - ArrayFromVector(is_valid, indices0_values, &indices0); - ArrayFromVector(is_valid, indices1_values, &indices1); - ArrayFromVector(is_valid, indices2_values, &indices2); - - auto a0 = std::make_shared(f0_type, indices0); - auto a1 = std::make_shared(f1_type, indices1); - auto a2 = std::make_shared(f2_type, indices2); - - // List of dictionary-encoded string - auto f3_type = list(f1_type); - - std::vector list_offsets = {0, 0, 2, 2, 5, 6, 9}; - std::shared_ptr offsets, indices3; - ArrayFromVector(std::vector(list_offsets.size(), true), - list_offsets, &offsets); - - std::vector indices3_values = {0, 1, 2, 0, 1, 2, 0, 1, 2}; - std::vector is_valid3(9, true); - ArrayFromVector(is_valid3, indices3_values, &indices3); - - std::shared_ptr null_bitmap; - RETURN_NOT_OK(GetBitmapFromVector(is_valid, &null_bitmap)); - - std::shared_ptr a3 = std::make_shared( - f3_type, length, std::static_pointer_cast(offsets)->values(), - std::make_shared(f1_type, indices3), null_bitmap, 1); +ARROW_EXPORT +void CompareArraysDetailed(int index, const Array& result, const Array& expected); - // Dictionary-encoded list of integer - auto f4_value_type = list(int8()); +ARROW_EXPORT +void CompareBatchColumnsDetailed(const RecordBatch& result, const RecordBatch& expected); - std::shared_ptr offsets4, values4, indices4; +ARROW_EXPORT +Status MakeRandomInt32Array(int64_t length, bool include_nulls, MemoryPool* pool, + std::shared_ptr* out, uint32_t seed = 0); - std::vector list_offsets4 = {0, 2, 2, 3}; - ArrayFromVector(std::vector(4, true), list_offsets4, - &offsets4); +ARROW_EXPORT +Status MakeRandomListArray(const std::shared_ptr& child_array, int num_lists, + bool include_nulls, MemoryPool* pool, + std::shared_ptr* out); - std::vector list_values4 = {0, 1, 2}; - ArrayFromVector(std::vector(3, true), list_values4, &values4); +ARROW_EXPORT +Status MakeRandomBooleanArray(const int length, bool include_nulls, + std::shared_ptr* out); - auto dict3 = std::make_shared( - f4_value_type, 3, std::static_pointer_cast(offsets4)->values(), - values4); +ARROW_EXPORT +Status MakeBooleanBatchSized(const int length, std::shared_ptr* out); - std::vector indices4_values = {0, 1, 2, 0, 1, 2}; - ArrayFromVector(is_valid, indices4_values, &indices4); +ARROW_EXPORT +Status MakeBooleanBatch(std::shared_ptr* out); - auto f4_type = dictionary(int8(), dict3); - auto a4 = std::make_shared(f4_type, indices4); - - // construct batch - auto schema = ::arrow::schema( - {field("dict1", f0_type), field("sparse", f1_type), field("dense", f2_type), - field("list of encoded string", f3_type), field("encoded list", f4_type)}); - - std::vector> arrays = {a0, a1, a2, a3, a4}; - - *out = RecordBatch::Make(schema, length, arrays); - return Status::OK(); -} - -static inline Status MakeDictionaryFlat(std::shared_ptr* out) { - const int64_t length = 6; - - std::vector is_valid = {true, true, false, true, true, true}; - std::shared_ptr dict1, dict2; - - std::vector dict1_values = {"foo", "bar", "baz"}; - std::vector dict2_values = {"foo", "bar", "baz", "qux"}; - - ArrayFromVector(dict1_values, &dict1); - ArrayFromVector(dict2_values, &dict2); - - auto f0_type = arrow::dictionary(arrow::int32(), dict1); - auto f1_type = arrow::dictionary(arrow::int8(), dict1); - auto f2_type = arrow::dictionary(arrow::int32(), dict2); - - std::shared_ptr indices0, indices1, indices2; - std::vector indices0_values = {1, 2, -1, 0, 2, 0}; - std::vector indices1_values = {0, 0, 2, 2, 1, 1}; - std::vector indices2_values = {3, 0, 2, 1, 0, 2}; - - ArrayFromVector(is_valid, indices0_values, &indices0); - ArrayFromVector(is_valid, indices1_values, &indices1); - ArrayFromVector(is_valid, indices2_values, &indices2); - - auto a0 = std::make_shared(f0_type, indices0); - auto a1 = std::make_shared(f1_type, indices1); - auto a2 = std::make_shared(f2_type, indices2); - - // construct batch - auto schema = ::arrow::schema( - {field("dict1", f0_type), field("sparse", f1_type), field("dense", f2_type)}); - - std::vector> arrays = {a0, a1, a2}; - *out = RecordBatch::Make(schema, length, arrays); - return Status::OK(); -} - -static inline Status MakeDates(std::shared_ptr* out) { - std::vector is_valid = {true, true, true, false, true, true, true}; - auto f0 = field("f0", date32()); - auto f1 = field("f1", date64()); - auto schema = ::arrow::schema({f0, f1}); - - std::vector date32_values = {0, 1, 2, 3, 4, 5, 6}; - std::shared_ptr date32_array; - ArrayFromVector(is_valid, date32_values, &date32_array); - - std::vector date64_values = {1489269000000, 1489270000000, 1489271000000, - 1489272000000, 1489272000000, 1489273000000, - 1489274000000}; - std::shared_ptr date64_array; - ArrayFromVector(is_valid, date64_values, &date64_array); - - *out = RecordBatch::Make(schema, date32_array->length(), {date32_array, date64_array}); - return Status::OK(); -} - -static inline Status MakeTimestamps(std::shared_ptr* out) { - std::vector is_valid = {true, true, true, false, true, true, true}; - auto f0 = field("f0", timestamp(TimeUnit::MILLI)); - auto f1 = field("f1", timestamp(TimeUnit::NANO, "America/New_York")); - auto f2 = field("f2", timestamp(TimeUnit::SECOND)); - auto schema = ::arrow::schema({f0, f1, f2}); +ARROW_EXPORT +Status MakeIntBatchSized(int length, std::shared_ptr* out, + uint32_t seed = 0); - std::vector ts_values = {1489269000000, 1489270000000, 1489271000000, - 1489272000000, 1489272000000, 1489273000000}; +ARROW_EXPORT +Status MakeIntRecordBatch(std::shared_ptr* out); - std::shared_ptr a0, a1, a2; - ArrayFromVector(f0->type(), is_valid, ts_values, &a0); - ArrayFromVector(f1->type(), is_valid, ts_values, &a1); - ArrayFromVector(f2->type(), is_valid, ts_values, &a2); +ARROW_EXPORT +Status MakeRandomStringArray(int64_t length, bool include_nulls, MemoryPool* pool, + std::shared_ptr* out); - *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2}); - return Status::OK(); -} +ARROW_EXPORT +Status MakeStringTypesRecordBatch(std::shared_ptr* out, + bool with_nulls = true); -static inline Status MakeTimes(std::shared_ptr* out) { - std::vector is_valid = {true, true, true, false, true, true, true}; - auto f0 = field("f0", time32(TimeUnit::MILLI)); - auto f1 = field("f1", time64(TimeUnit::NANO)); - auto f2 = field("f2", time32(TimeUnit::SECOND)); - auto f3 = field("f3", time64(TimeUnit::NANO)); - auto schema = ::arrow::schema({f0, f1, f2, f3}); +ARROW_EXPORT +Status MakeStringTypesRecordBatchWithNulls(std::shared_ptr* out); - std::vector t32_values = {1489269000, 1489270000, 1489271000, - 1489272000, 1489272000, 1489273000}; - std::vector t64_values = {1489269000000, 1489270000000, 1489271000000, - 1489272000000, 1489272000000, 1489273000000}; - - std::shared_ptr a0, a1, a2, a3; - ArrayFromVector(f0->type(), is_valid, t32_values, &a0); - ArrayFromVector(f1->type(), is_valid, t64_values, &a1); - ArrayFromVector(f2->type(), is_valid, t32_values, &a2); - ArrayFromVector(f3->type(), is_valid, t64_values, &a3); +ARROW_EXPORT +Status MakeNullRecordBatch(std::shared_ptr* out); - *out = RecordBatch::Make(schema, a0->length(), {a0, a1, a2, a3}); - return Status::OK(); -} +ARROW_EXPORT +Status MakeListRecordBatch(std::shared_ptr* out); -template -void AppendValues(const std::vector& is_valid, const std::vector& values, - BuilderType* builder) { - for (size_t i = 0; i < values.size(); ++i) { - if (is_valid[i]) { - ASSERT_OK(builder->Append(values[i])); - } else { - ASSERT_OK(builder->AppendNull()); - } - } -} - -static inline Status MakeFWBinary(std::shared_ptr* out) { - std::vector is_valid = {true, true, true, false}; - auto f0 = field("f0", fixed_size_binary(4)); - auto f1 = field("f1", fixed_size_binary(0)); - auto schema = ::arrow::schema({f0, f1}); - - std::shared_ptr a1, a2; - - FixedSizeBinaryBuilder b1(f0->type()); - FixedSizeBinaryBuilder b2(f1->type()); - - std::vector values1 = {"foo1", "foo2", "foo3", "foo4"}; - AppendValues(is_valid, values1, &b1); - - std::vector values2 = {"", "", "", ""}; - AppendValues(is_valid, values2, &b2); - - RETURN_NOT_OK(b1.Finish(&a1)); - RETURN_NOT_OK(b2.Finish(&a2)); - - *out = RecordBatch::Make(schema, a1->length(), {a1, a2}); - return Status::OK(); -} - -static inline Status MakeDecimal(std::shared_ptr* out) { - constexpr int kDecimalPrecision = 38; - auto type = decimal(kDecimalPrecision, 4); - auto f0 = field("f0", type); - auto f1 = field("f1", type); - auto schema = ::arrow::schema({f0, f1}); +ARROW_EXPORT +Status MakeZeroLengthRecordBatch(std::shared_ptr* out); - constexpr int kDecimalSize = 16; - constexpr int length = 10; - - std::shared_ptr data, is_valid; - std::vector is_valid_bytes(length); - - RETURN_NOT_OK(AllocateBuffer(kDecimalSize * length, &data)); +ARROW_EXPORT +Status MakeNonNullRecordBatch(std::shared_ptr* out); - random_decimals(length, 1, kDecimalPrecision, data->mutable_data()); - random_null_bytes(length, 0.1, is_valid_bytes.data()); +ARROW_EXPORT +Status MakeDeeplyNestedList(std::shared_ptr* out); - RETURN_NOT_OK(BitUtil::BytesToBits(is_valid_bytes, default_memory_pool(), &is_valid)); +ARROW_EXPORT +Status MakeStruct(std::shared_ptr* out); - auto a1 = std::make_shared(f0->type(), length, data, is_valid, - kUnknownNullCount); +ARROW_EXPORT +Status MakeUnion(std::shared_ptr* out); - auto a2 = std::make_shared(f1->type(), length, data); +ARROW_EXPORT +Status MakeDictionary(std::shared_ptr* out); - *out = RecordBatch::Make(schema, length, {a1, a2}); - return Status::OK(); -} +ARROW_EXPORT +Status MakeDictionaryFlat(std::shared_ptr* out); -static inline Status MakeNull(std::shared_ptr* out) { - auto f0 = field("f0", null()); +ARROW_EXPORT +Status MakeDates(std::shared_ptr* out); - // Also put a non-null field to make sure we handle the null array buffers properly - auto f1 = field("f1", int64()); +ARROW_EXPORT +Status MakeTimestamps(std::shared_ptr* out); - auto schema = ::arrow::schema({f0, f1}); +ARROW_EXPORT +Status MakeTimes(std::shared_ptr* out); - auto a1 = std::make_shared(10); +ARROW_EXPORT +Status MakeFWBinary(std::shared_ptr* out); - std::vector int_values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector is_valid = {true, true, true, false, false, - true, true, true, true, true}; - std::shared_ptr a2; - ArrayFromVector(f1->type(), is_valid, int_values, &a2); +ARROW_EXPORT +Status MakeDecimal(std::shared_ptr* out); - *out = RecordBatch::Make(schema, a1->length(), {a1, a2}); - return Status::OK(); -} +ARROW_EXPORT +Status MakeNull(std::shared_ptr* out); +} // namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index ba9939016f1..bc89dc48da4 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "arrow/array.h" @@ -43,12 +44,14 @@ #include "arrow/util/bit-util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" +#include "arrow/util/stl.h" #include "arrow/visitor.h" namespace arrow { using internal::checked_cast; using internal::CopyBitmap; +using internal::make_unique; namespace ipc { @@ -529,17 +532,46 @@ Status WriteIpcPayload(const IpcPayload& payload, io::OutputStream* dst, return Status::OK(); } -Status GetSchemaPayload(const Schema& schema, MemoryPool* pool, - DictionaryMemo* dictionary_memo, IpcPayload* out) { - out->type = Message::Type::SCHEMA; - out->body_buffers.clear(); - out->body_length = 0; - RETURN_NOT_OK(SerializeSchema(schema, pool, &out->metadata)); - return WriteSchemaMessage(schema, dictionary_memo, &out->metadata); +Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, DictionaryMemo* out_memo, + std::vector* out_payloads) { + DictionaryMemo dictionary_memo; + IpcPayload payload; + + out_payloads->clear(); + payload.type = Message::SCHEMA; + RETURN_NOT_OK(WriteSchemaMessage(schema, &dictionary_memo, &payload.metadata)); + out_payloads->push_back(std::move(payload)); + out_payloads->reserve(dictionary_memo.size() + 1); + + // Append dictionaries + for (auto& pair : dictionary_memo.id_to_dictionary()) { + int64_t dictionary_id = pair.first; + const auto& dictionary = pair.second; + + // Frame of reference is 0, see ARROW-384 + const int64_t buffer_start_offset = 0; + payload.type = Message::DICTIONARY_BATCH; + DictionaryWriter writer(dictionary_id, pool, buffer_start_offset, kMaxNestingDepth, + true /* allow_64bit */, &payload); + RETURN_NOT_OK(writer.Assemble(dictionary)); + out_payloads->push_back(std::move(payload)); + } + + if (out_memo != nullptr) { + *out_memo = std::move(dictionary_memo); + } + + return Status::OK(); +} + +Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, + std::vector* out_payloads) { + return GetSchemaPayloads(schema, pool, nullptr, out_payloads); } Status GetRecordBatchPayload(const RecordBatch& batch, MemoryPool* pool, IpcPayload* out) { + out->type = Message::RECORD_BATCH; RecordBatchSerializer writer(pool, 0, kMaxNestingDepth, true, out); return writer.Assemble(batch); } @@ -846,11 +878,93 @@ Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); } // ---------------------------------------------------------------------- -// Stream writer implementation +// Payload writer implementation + +namespace internal { + +IpcPayloadWriter::~IpcPayloadWriter() {} + +Status IpcPayloadWriter::Start() { return Status::OK(); } + +} // namespace internal + +namespace { + +/// A RecordBatchWriter implementation that writes to a IpcPayloadWriter. +class RecordBatchPayloadWriter : public RecordBatchWriter { + public: + ~RecordBatchPayloadWriter() override = default; + + RecordBatchPayloadWriter(std::unique_ptr payload_writer, + const Schema& schema) + : payload_writer_(std::move(payload_writer)), + schema_(schema), + pool_(default_memory_pool()), + started_(false) {} + + // A Schema-owning constructor variant + RecordBatchPayloadWriter(std::unique_ptr payload_writer, + const std::shared_ptr& schema) + : payload_writer_(std::move(payload_writer)), + shared_schema_(schema), + schema_(*schema), + pool_(default_memory_pool()), + started_(false) {} + + Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit = false) override { + if (!batch.schema()->Equals(schema_, false /* check_metadata */)) { + return Status::Invalid("Tried to write record batch with different schema"); + } + + RETURN_NOT_OK(CheckStarted()); + internal::IpcPayload payload; + RETURN_NOT_OK(GetRecordBatchPayload(batch, pool_, &payload)); + return payload_writer_->WritePayload(payload); + } + + Status Close() override { + RETURN_NOT_OK(CheckStarted()); + return payload_writer_->Close(); + } + + void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } + + Status Start() { + started_ = true; + RETURN_NOT_OK(payload_writer_->Start()); + + // Write out schema payloads + std::vector payloads; + // XXX should we have a GetSchemaPayloads() variant that generates them + // one by one, to minimize memory usage? + RETURN_NOT_OK(GetSchemaPayloads(schema_, pool_, &payloads)); + for (const auto& payload : payloads) { + RETURN_NOT_OK(payload_writer_->WritePayload(payload)); + } + return Status::OK(); + } + + protected: + Status CheckStarted() { + if (!started_) { + return Start(); + } + return Status::OK(); + } + + protected: + std::unique_ptr payload_writer_; + std::shared_ptr shared_schema_; + const Schema& schema_; + MemoryPool* pool_; + bool started_; +}; + +// ---------------------------------------------------------------------- +// Stream and file writer implementation class StreamBookKeeper { public: - StreamBookKeeper() : sink_(nullptr), position_(-1) {} explicit StreamBookKeeper(io::OutputStream* sink) : sink_(sink), position_(-1) {} Status UpdatePosition() { return sink_->Tell(&position_); } @@ -883,142 +997,131 @@ class StreamBookKeeper { int64_t position_; }; -class SchemaWriter : public StreamBookKeeper { +/// A IpcPayloadWriter implementation that writes to a IPC stream +/// (with an end-of-stream marker) +class PayloadStreamWriter : public internal::IpcPayloadWriter, + protected StreamBookKeeper { public: - SchemaWriter(const Schema& schema, DictionaryMemo* dictionary_memo, MemoryPool* pool, - io::OutputStream* sink) - : StreamBookKeeper(sink), - pool_(pool), - schema_(schema), - dictionary_memo_(dictionary_memo) {} + explicit PayloadStreamWriter(io::OutputStream* sink) : StreamBookKeeper(sink) {} + + ~PayloadStreamWriter() override = default; - Status WriteSchema() { + Status WritePayload(const internal::IpcPayload& payload) override { #ifndef NDEBUG // Catch bug fixed in ARROW-3236 RETURN_NOT_OK(UpdatePositionCheckAligned()); #endif - std::shared_ptr schema_fb; - RETURN_NOT_OK(internal::WriteSchemaMessage(schema_, dictionary_memo_, &schema_fb)); - - int32_t metadata_length = 0; - RETURN_NOT_OK(internal::WriteMessage(*schema_fb, 8, sink_, &metadata_length)); + int32_t metadata_length = 0; // unused + RETURN_NOT_OK(WriteIpcPayload(payload, sink_, &metadata_length)); RETURN_NOT_OK(UpdatePositionCheckAligned()); return Status::OK(); } - Status WriteDictionaries(std::vector* dictionaries) { - const DictionaryMap& id_to_dictionary = dictionary_memo_->id_to_dictionary(); - - dictionaries->resize(id_to_dictionary.size()); - - // TODO(wesm): does sorting by id yield any benefit? - int dict_index = 0; - for (const auto& entry : id_to_dictionary) { - FileBlock* block = &(*dictionaries)[dict_index++]; - - block->offset = position_; - - // Frame of reference in file format is 0, see ARROW-384 - const int64_t buffer_start_offset = 0; - RETURN_NOT_OK(WriteDictionary(entry.first, entry.second, buffer_start_offset, sink_, - &block->metadata_length, &block->body_length, pool_)); - RETURN_NOT_OK(UpdatePositionCheckAligned()); - } - - return Status::OK(); - } - - Status Write(std::vector* dictionaries) { - RETURN_NOT_OK(WriteSchema()); - - // If there are any dictionaries, write them as the next messages - return WriteDictionaries(dictionaries); + Status Close() override { + // Write 0 EOS message + const int32_t kEos = 0; + return Write(&kEos, sizeof(int32_t)); } - - private: - MemoryPool* pool_; - const Schema& schema_; - DictionaryMemo* dictionary_memo_; }; -class RecordBatchStreamWriter::RecordBatchStreamWriterImpl : public StreamBookKeeper { +/// A IpcPayloadWriter implementation that writes to a IPC file +/// (with a footer as defined in File.fbs) +class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBookKeeper { public: - RecordBatchStreamWriterImpl(io::OutputStream* sink, - const std::shared_ptr& schema) - : StreamBookKeeper(sink), - schema_(schema), - pool_(default_memory_pool()), - started_(false) {} + PayloadFileWriter(io::OutputStream* sink, const std::shared_ptr& schema) + : StreamBookKeeper(sink), schema_(schema) {} - virtual ~RecordBatchStreamWriterImpl() = default; + ~PayloadFileWriter() override = default; - virtual Status Start() { - SchemaWriter schema_writer(*schema_, &dictionary_memo_, pool_, sink_); - RETURN_NOT_OK(schema_writer.Write(&dictionaries_)); - started_ = true; - return Status::OK(); - } - - virtual Status Close() { - // Write the schema if not already written - // User is responsible for closing the OutputStream - RETURN_NOT_OK(CheckStarted()); + Status WritePayload(const internal::IpcPayload& payload) override { +#ifndef NDEBUG + // Catch bug fixed in ARROW-3236 + RETURN_NOT_OK(UpdatePositionCheckAligned()); +#endif - // Write 0 EOS message - const int32_t kEos = 0; - return Write(&kEos, sizeof(int32_t)); - } + // Metadata length must include padding, it's computed by WriteIpcPayload() + FileBlock block = {position_, 0, payload.body_length}; + RETURN_NOT_OK(WriteIpcPayload(payload, sink_, &block.metadata_length)); + RETURN_NOT_OK(UpdatePositionCheckAligned()); - Status CheckStarted() { - if (!started_) { - return Start(); + // Record position and size of some message types, to list them in the footer + switch (payload.type) { + case Message::DICTIONARY_BATCH: + dictionaries_.push_back(block); + break; + case Message::RECORD_BATCH: + record_batches_.push_back(block); + break; + default: + break; } + return Status::OK(); } - Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit, FileBlock* block) { - RETURN_NOT_OK(CheckStarted()); + Status Start() override { + // ARROW-3236: The initial position -1 needs to be updated to the stream's + // current position otherwise an incorrect amount of padding will be + // written to new files. RETURN_NOT_OK(UpdatePosition()); - block->offset = position_; - - // Frame of reference in file format is 0, see ARROW-384 - const int64_t buffer_start_offset = 0; - RETURN_NOT_OK(arrow::ipc::WriteRecordBatch( - batch, buffer_start_offset, sink_, &block->metadata_length, &block->body_length, - pool_, kMaxNestingDepth, allow_64bit)); - RETURN_NOT_OK(UpdatePositionCheckAligned()); + // It is only necessary to align to 8-byte boundary at the start of the file + RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes))); + RETURN_NOT_OK(Align()); return Status::OK(); } - Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) { - // Push an empty FileBlock. Can be written in the footer later - if (!batch.schema()->Equals(*schema_, false /* check_metadata */)) { - return Status::Invalid("Tried to write record batch with different schema"); + Status Close() override { + // Write file footer + RETURN_NOT_OK(UpdatePosition()); + int64_t initial_position = position_; + RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_, sink_)); + + // Write footer length + RETURN_NOT_OK(UpdatePosition()); + int32_t footer_length = static_cast(position_ - initial_position); + if (footer_length <= 0) { + return Status::Invalid("Invalid file footer"); } - record_batches_.push_back({0, 0, 0}); - return WriteRecordBatch(batch, allow_64bit, - &record_batches_[record_batches_.size() - 1]); - } + RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t))); - void set_memory_pool(MemoryPool* pool) { pool_ = pool; } + // Write magic bytes to end file + return Write(kArrowMagicBytes, strlen(kArrowMagicBytes)); + } protected: std::shared_ptr schema_; - MemoryPool* pool_; - bool started_; - - // When writing out the schema, we keep track of all the dictionaries we - // encounter, as they must be written out first in the stream - DictionaryMemo dictionary_memo_; - std::vector dictionaries_; std::vector record_batches_; }; +} // namespace + +class RecordBatchStreamWriter::RecordBatchStreamWriterImpl + : public RecordBatchPayloadWriter { + public: + RecordBatchStreamWriterImpl(io::OutputStream* sink, + const std::shared_ptr& schema) + : RecordBatchPayloadWriter( + std::unique_ptr(new PayloadStreamWriter(sink)), + schema) {} + + ~RecordBatchStreamWriterImpl() = default; +}; + +class RecordBatchFileWriter::RecordBatchFileWriterImpl : public RecordBatchPayloadWriter { + public: + RecordBatchFileWriterImpl(io::OutputStream* sink, const std::shared_ptr& schema) + : RecordBatchPayloadWriter(std::unique_ptr( + new PayloadFileWriter(sink, schema)), + schema) {} + + ~RecordBatchFileWriterImpl() = default; +}; + RecordBatchStreamWriter::RecordBatchStreamWriter() {} RecordBatchStreamWriter::~RecordBatchStreamWriter() {} @@ -1044,59 +1147,6 @@ Status RecordBatchStreamWriter::Open(io::OutputStream* sink, Status RecordBatchStreamWriter::Close() { return impl_->Close(); } -// ---------------------------------------------------------------------- -// File writer implementation - -class RecordBatchFileWriter::RecordBatchFileWriterImpl - : public RecordBatchStreamWriter::RecordBatchStreamWriterImpl { - public: - using BASE = RecordBatchStreamWriter::RecordBatchStreamWriterImpl; - - RecordBatchFileWriterImpl(io::OutputStream* sink, const std::shared_ptr& schema) - : BASE(sink, schema) {} - - Status Start() override { - // ARROW-3236: The initial position -1 needs to be updated to the stream's - // current position otherwise an incorrect amount of padding will be - // written to new files. - RETURN_NOT_OK(UpdatePosition()); - - // It is only necessary to align to 8-byte boundary at the start of the file - RETURN_NOT_OK(Write(kArrowMagicBytes, strlen(kArrowMagicBytes))); - RETURN_NOT_OK(Align()); - - // We write the schema at the start of the file (and the end). This also - // writes all the dictionaries at the beginning of the file - return BASE::Start(); - } - - Status Close() override { - // Write the schema if not already written - // User is responsible for closing the OutputStream - RETURN_NOT_OK(CheckStarted()); - - // Write metadata - RETURN_NOT_OK(UpdatePosition()); - - int64_t initial_position = position_; - RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_, - &dictionary_memo_, sink_)); - RETURN_NOT_OK(UpdatePosition()); - - // Write footer length - int32_t footer_length = static_cast(position_ - initial_position); - - if (footer_length <= 0) { - return Status::Invalid("Invalid file footer"); - } - - RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t))); - - // Write magic bytes to end file - return Write(kArrowMagicBytes, strlen(kArrowMagicBytes)); - } -}; - RecordBatchFileWriter::RecordBatchFileWriter() {} RecordBatchFileWriter::~RecordBatchFileWriter() {} @@ -1118,6 +1168,18 @@ Status RecordBatchFileWriter::WriteRecordBatch(const RecordBatch& batch, Status RecordBatchFileWriter::Close() { return file_impl_->Close(); } +namespace internal { + +Status OpenRecordBatchWriter(std::unique_ptr sink, + const std::shared_ptr& schema, + std::unique_ptr* out) { + out->reset(new RecordBatchPayloadWriter(std::move(sink), schema)); + // XXX should we call Start()? + return Status::OK(); +} + +} // namespace internal + // ---------------------------------------------------------------------- // Serialization public APIs @@ -1142,18 +1204,20 @@ Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, kMaxNestingDepth, true); } +// TODO: this function also serializes dictionaries. This is suboptimal for +// the purpose of transmitting working set metadata without actually sending +// the data (e.g. ListFlights() in Flight RPC). + Status SerializeSchema(const Schema& schema, MemoryPool* pool, std::shared_ptr* out) { std::shared_ptr stream; RETURN_NOT_OK(io::BufferOutputStream::Create(1024, pool, &stream)); - DictionaryMemo memo; - SchemaWriter schema_writer(schema, &memo, pool, stream.get()); - - // Unused - std::vector dictionary_blocks; + auto payload_writer = make_unique(stream.get()); + RecordBatchPayloadWriter writer(std::move(payload_writer), schema); + // Write out schema and dictionaries + RETURN_NOT_OK(writer.Start()); - RETURN_NOT_OK(schema_writer.Write(&dictionary_blocks)); return stream->Finish(out); } diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index 50872e9377b..75034ea9ae9 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -302,28 +302,48 @@ namespace internal { // Intermediate data structure with metadata header, and zero or more buffers // for the message body. struct IpcPayload { - Message::Type type; + Message::Type type = Message::NONE; std::shared_ptr metadata; std::vector> body_buffers; - int64_t body_length; + int64_t body_length = 0; }; -/// \brief Extract IPC payloads from given schema for purposes of wire -/// transport, separate from using the *StreamWriter classes +class ARROW_EXPORT IpcPayloadWriter { + public: + virtual ~IpcPayloadWriter(); + + // Default implementation is a no-op + virtual Status Start(); + + virtual Status WritePayload(const IpcPayload& payload) = 0; + + virtual Status Close() = 0; +}; + +/// Create a new RecordBatchWriter from IpcPayloadWriter and schema. +/// +/// \param[in] sink the IpcPayloadWriter to write to +/// \param[in] schema the schema of the record batches to be written +/// \param[out] out the created RecordBatchWriter +/// \return Status ARROW_EXPORT -Status GetDictionaryPayloads(const Schema& schema, - std::vector>* out); +Status OpenRecordBatchWriter(std::unique_ptr sink, + const std::shared_ptr& schema, + std::unique_ptr* out); -/// \brief Compute IpcPayload for the given schema +/// \brief Compute IpcPayloads for the given schema /// \param[in] schema the Schema that is being serialized /// \param[in,out] pool for any required temporary memory allocations /// \param[in,out] dictionary_memo class for tracking dictionaries and assigning /// dictionary ids -/// \param[out] out the returned IpcPayload +/// \param[out] out the returned vector of IpcPayloads /// \return Status ARROW_EXPORT -Status GetSchemaPayload(const Schema& schema, MemoryPool* pool, - DictionaryMemo* dictionary_memo, IpcPayload* out); +Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, + DictionaryMemo* dictionary_memo, std::vector* out); +ARROW_EXPORT +Status GetSchemaPayloads(const Schema& schema, MemoryPool* pool, + std::vector* out); /// \brief Compute IpcPayload for the given record batch /// \param[in] batch the RecordBatch that is being serialized diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index f5810a7120d..01fb29d7ff4 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -1889,11 +1889,8 @@ class ArrowDeserializer { return Status::OK(); } - Status Visit(const UnionType& type) { return Status::NotImplemented("union type"); } - - Status Visit(const ExtensionType& type) { - return Status::NotImplemented("extension type"); - } + // Default case + Status Visit(const DataType& type) { return Status::NotImplemented(type.name()); } Status Convert(PyObject** out) { RETURN_NOT_OK(VisitTypeInline(*col_->type(), this)); diff --git a/cpp/src/arrow/python/numpy_to_arrow.cc b/cpp/src/arrow/python/numpy_to_arrow.cc index 36f3ccb3f8e..ca3f5969eba 100644 --- a/cpp/src/arrow/python/numpy_to_arrow.cc +++ b/cpp/src/arrow/python/numpy_to_arrow.cc @@ -234,13 +234,8 @@ class NumPyConverter { Status Visit(const FixedSizeBinaryType& type); - Status Visit(const Decimal128Type& type) { return TypeNotImplemented(type.ToString()); } - - Status Visit(const DictionaryType& type) { return TypeNotImplemented(type.ToString()); } - - Status Visit(const NestedType& type) { return TypeNotImplemented(type.ToString()); } - - Status Visit(const ExtensionType& type) { return TypeNotImplemented(type.ToString()); } + // Default case + Status Visit(const DataType& type) { return TypeNotImplemented(type.ToString()); } protected: Status InitNullBitmap() { diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc index 8e667f8ad8a..d0629635530 100644 --- a/cpp/src/gandiva/expression_registry.cc +++ b/cpp/src/gandiva/expression_registry.cc @@ -139,15 +139,8 @@ void ExpressionRegistry::AddArrowTypesToVector(arrow::Type::type& type, case arrow::Type::type::DECIMAL: vector.push_back(arrow::decimal(38, 0)); break; - case arrow::Type::type::FIXED_SIZE_BINARY: - case arrow::Type::type::MAP: - case arrow::Type::type::INTERVAL: - case arrow::Type::type::LIST: - case arrow::Type::type::STRUCT: - case arrow::Type::type::UNION: - case arrow::Type::type::DICTIONARY: - case arrow::Type::type::EXTENSION: - // un-supported types. test ensures that + default: + // Unsupported types. test ensures that // when one of these are added build breaks. DCHECK(false); } diff --git a/cpp/src/gandiva/jni/expression_registry_helper.cc b/cpp/src/gandiva/jni/expression_registry_helper.cc index 2275641301b..7b7834d5a42 100644 --- a/cpp/src/gandiva/jni/expression_registry_helper.cc +++ b/cpp/src/gandiva/jni/expression_registry_helper.cc @@ -127,14 +127,7 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) gandiva_data_type->set_scale(0); break; } - case arrow::Type::type::FIXED_SIZE_BINARY: - case arrow::Type::type::MAP: - case arrow::Type::type::INTERVAL: - case arrow::Type::type::LIST: - case arrow::Type::type::STRUCT: - case arrow::Type::type::UNION: - case arrow::Type::type::DICTIONARY: - case arrow::Type::type::EXTENSION: + default: // un-supported types. test ensures that // when one of these are added build breaks. DCHECK(false);