diff --git a/c_glib/arrow-glib/reader.cpp b/c_glib/arrow-glib/reader.cpp index 3713a264b19..3190d240026 100644 --- a/c_glib/arrow-glib/reader.cpp +++ b/c_glib/arrow-glib/reader.cpp @@ -296,14 +296,12 @@ GArrowRecordBatchStreamReader * garrow_record_batch_stream_reader_new(GArrowInputStream *stream, GError **error) { - using BaseType = arrow::ipc::RecordBatchReader; using ReaderType = arrow::ipc::RecordBatchStreamReader; auto arrow_input_stream = garrow_input_stream_get_raw(stream); - std::shared_ptr arrow_reader; - auto status = ReaderType::Open(arrow_input_stream, &arrow_reader); - if (garrow_error_check(error, status, "[record-batch-stream-reader][open]")) { - auto subtype = std::dynamic_pointer_cast(arrow_reader); + auto arrow_reader = ReaderType::Open(arrow_input_stream); + if (garrow::check(error, arrow_reader, "[record-batch-stream-reader][open]")) { + auto subtype = std::dynamic_pointer_cast(*arrow_reader); return garrow_record_batch_stream_reader_new_raw(&subtype); } else { return NULL; @@ -411,14 +409,13 @@ GArrowRecordBatchFileReader * garrow_record_batch_file_reader_new(GArrowSeekableInputStream *file, GError **error) { - auto arrow_random_access_file = garrow_seekable_input_stream_get_raw(file); + using ReaderType = arrow::ipc::RecordBatchFileReader; - std::shared_ptr arrow_reader; - auto status = - arrow::ipc::RecordBatchFileReader::Open(arrow_random_access_file, - &arrow_reader); - if (garrow_error_check(error, status, "[record-batch-file-reader][open]")) { - return garrow_record_batch_file_reader_new_raw(&arrow_reader); + auto arrow_random_access_file = garrow_seekable_input_stream_get_raw(file); + auto arrow_reader = ReaderType::Open(arrow_random_access_file); + if (garrow::check(error, arrow_reader, "[record-batch-file-reader][open]")) { + auto subtype = std::dynamic_pointer_cast(*arrow_reader); + return garrow_record_batch_file_reader_new_raw(&subtype); } else { return NULL; } diff --git a/c_glib/arrow-glib/writer.cpp b/c_glib/arrow-glib/writer.cpp index 4aa2ba485f0..8e9132c5a89 100644 --- a/c_glib/arrow-glib/writer.cpp +++ b/c_glib/arrow-glib/writer.cpp @@ -238,17 +238,15 @@ garrow_record_batch_stream_writer_new(GArrowOutputStream *sink, GArrowSchema *schema, GError **error) { - using BaseType = arrow::ipc::RecordBatchWriter; - using WriterType = arrow::ipc::RecordBatchStreamWriter; - auto arrow_sink = garrow_output_stream_get_raw(sink).get(); - std::shared_ptr arrow_writer; - auto status = WriterType::Open(arrow_sink, - garrow_schema_get_raw(schema), - &arrow_writer); - if (garrow_error_check(error, status, "[record-batch-stream-writer][open]")) { - auto subtype = std::dynamic_pointer_cast(arrow_writer); - return garrow_record_batch_stream_writer_new_raw(&subtype); + auto arrow_schema = garrow_schema_get_raw(schema); + auto arrow_writer_result = + arrow::ipc::NewStreamWriter(arrow_sink, arrow_schema); + if (garrow::check(error, + arrow_writer_result, + "[record-batch-stream-writer][open]")) { + auto arrow_writer = *arrow_writer_result; + return garrow_record_batch_stream_writer_new_raw(&arrow_writer); } else { return NULL; } @@ -285,17 +283,16 @@ garrow_record_batch_file_writer_new(GArrowOutputStream *sink, GArrowSchema *schema, GError **error) { - using BaseType = arrow::ipc::RecordBatchWriter; - using WriterType = arrow::ipc::RecordBatchFileWriter; - - auto arrow_sink = garrow_output_stream_get_raw(sink); - std::shared_ptr arrow_writer; - auto status = WriterType::Open(arrow_sink.get(), - garrow_schema_get_raw(schema), - &arrow_writer); - if (garrow_error_check(error, status, "[record-batch-file-writer][open]")) { - auto subtype = std::dynamic_pointer_cast(arrow_writer); - return garrow_record_batch_file_writer_new_raw(&subtype); + auto arrow_sink = garrow_output_stream_get_raw(sink).get(); + auto arrow_schema = garrow_schema_get_raw(schema); + std::shared_ptr arrow_writer; + auto arrow_writer_result = + arrow::ipc::NewFileWriter(arrow_sink, arrow_schema); + if (garrow::check(error, + arrow_writer_result, + "[record-batch-file-writer][open]")) { + auto arrow_writer = *arrow_writer_result; + return garrow_record_batch_file_writer_new_raw(&arrow_writer); } else { return NULL; } @@ -529,7 +526,7 @@ garrow_record_batch_writer_get_raw(GArrowRecordBatchWriter *writer) } GArrowRecordBatchStreamWriter * -garrow_record_batch_stream_writer_new_raw(std::shared_ptr *arrow_writer) +garrow_record_batch_stream_writer_new_raw(std::shared_ptr *arrow_writer) { auto writer = GARROW_RECORD_BATCH_STREAM_WRITER( @@ -540,7 +537,7 @@ garrow_record_batch_stream_writer_new_raw(std::shared_ptr *arrow_writer) +garrow_record_batch_file_writer_new_raw(std::shared_ptr *arrow_writer) { auto writer = GARROW_RECORD_BATCH_FILE_WRITER( diff --git a/c_glib/arrow-glib/writer.hpp b/c_glib/arrow-glib/writer.hpp index d57f69b657d..61d9d679dc3 100644 --- a/c_glib/arrow-glib/writer.hpp +++ b/c_glib/arrow-glib/writer.hpp @@ -28,9 +28,9 @@ GArrowRecordBatchWriter *garrow_record_batch_writer_new_raw(std::shared_ptr *arrow_writer); std::shared_ptr garrow_record_batch_writer_get_raw(GArrowRecordBatchWriter *writer); -GArrowRecordBatchStreamWriter *garrow_record_batch_stream_writer_new_raw(std::shared_ptr *arrow_writer); +GArrowRecordBatchStreamWriter *garrow_record_batch_stream_writer_new_raw(std::shared_ptr *arrow_writer); -GArrowRecordBatchFileWriter *garrow_record_batch_file_writer_new_raw(std::shared_ptr *arrow_writer); +GArrowRecordBatchFileWriter *garrow_record_batch_file_writer_new_raw(std::shared_ptr *arrow_writer); GArrowFeatherFileWriter *garrow_feather_file_writer_new_raw(arrow::ipc::feather::TableWriter *arrow_writer); arrow::ipc::feather::TableWriter *garrow_feather_file_writer_get_raw(GArrowFeatherFileWriter *writer); diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index eaabcf2f6b1..60ad599284a 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -17,7 +17,7 @@ #pragma once -#include +#include // IWYU pragma: export #include #include #include diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 7dbdda736e4..1169eb4f924 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -273,15 +273,24 @@ class ARROW_EXPORT BinaryKernel : public OpKernel { Datum* out) = 0; }; +// TODO doxygen 1.8.16 does not like the following code +///@cond INTERNAL + static inline bool CollectionEquals(const std::vector& left, const std::vector& right) { - if (left.size() != right.size()) return false; - - for (size_t i = 0; i < left.size(); i++) - if (!left[i].Equals(right[i])) return false; + if (left.size() != right.size()) { + return false; + } + for (size_t i = 0; i < left.size(); i++) { + if (!left[i].Equals(right[i])) { + return false; + } + } return true; } +///@endcond + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc index 957dec9cdf3..fbb970afec9 100644 --- a/cpp/src/arrow/dataset/file_ipc.cc +++ b/cpp/src/arrow/dataset/file_ipc.cc @@ -19,15 +19,14 @@ #include #include -#include #include #include #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/filter.h" #include "arrow/dataset/scanner.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/util/iterator.h" namespace arrow { @@ -40,12 +39,11 @@ Result> OpenReader( } std::shared_ptr reader; - auto status = ipc::RecordBatchFileReader::Open(std::move(input), &reader); + auto status = ipc::RecordBatchFileReader::Open(std::move(input)).Value(&reader); if (!status.ok()) { return status.WithMessage("Could not open IPC input source '", source.path(), "': ", status.message()); } - return reader; } @@ -161,10 +159,8 @@ Result> IpcFileFormat::WriteFragment( RETURN_NOT_OK(CreateDestinationParentDir()); ARROW_ASSIGN_OR_RAISE(auto out_stream, destination_.OpenWritable()); - - ARROW_ASSIGN_OR_RAISE(auto writer, ipc::RecordBatchFileWriter::Open( - out_stream.get(), fragment_->schema())); - + ARROW_ASSIGN_OR_RAISE(auto writer, + ipc::NewFileWriter(out_stream.get(), fragment_->schema())); ARROW_ASSIGN_OR_RAISE(auto scan_task_it, fragment_->Scan(scan_context_)); for (auto maybe_scan_task : scan_task_it) { diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h index 5c7be4486d8..57c98f54d24 100644 --- a/cpp/src/arrow/dataset/file_ipc.h +++ b/cpp/src/arrow/dataset/file_ipc.h @@ -19,11 +19,11 @@ #include #include -#include #include "arrow/dataset/file_base.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" +#include "arrow/result.h" namespace arrow { namespace dataset { diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index 7a160c02062..f50a1a494a5 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -46,8 +46,7 @@ class ArrowIpcWriterMixin : public ::testing::Test { std::shared_ptr Write(RecordBatchReader* reader) { EXPECT_OK_AND_ASSIGN(auto sink, io::BufferOutputStream::Create()); - EXPECT_OK_AND_ASSIGN(auto writer, - ipc::RecordBatchFileWriter::Open(sink.get(), reader->schema())); + EXPECT_OK_AND_ASSIGN(auto writer, ipc::NewFileWriter(sink.get(), reader->schema())); std::vector> batches; ARROW_EXPECT_OK(reader->ReadAll(&batches)); @@ -63,9 +62,7 @@ class ArrowIpcWriterMixin : public ::testing::Test { std::shared_ptr Write(const Table& table) { EXPECT_OK_AND_ASSIGN(auto sink, io::BufferOutputStream::Create()); - - EXPECT_OK_AND_ASSIGN(auto writer, - ipc::RecordBatchFileWriter::Open(sink.get(), table.schema())); + EXPECT_OK_AND_ASSIGN(auto writer, ipc::NewFileWriter(sink.get(), table.schema())); ARROW_EXPECT_OK(writer->WriteTable(table)); diff --git a/cpp/src/arrow/extension_type_test.cc b/cpp/src/arrow/extension_type_test.cc index a26f57e667f..ed44706a658 100644 --- a/cpp/src/arrow/extension_type_test.cc +++ b/cpp/src/arrow/extension_type_test.cc @@ -218,14 +218,14 @@ TEST_F(TestExtensionType, ExtensionTypeTest) { auto RoundtripBatch = [](const std::shared_ptr& batch, std::shared_ptr* out) { ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); - ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcOptions::Defaults(), + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), out_stream.get())); ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); io::BufferReader reader(complete_ipc_stream); std::shared_ptr batch_reader; - ASSERT_OK(ipc::RecordBatchStreamReader::Open(&reader, &batch_reader)); + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); ASSERT_OK(batch_reader->ReadNext(out)); }; @@ -256,7 +256,7 @@ TEST_F(TestExtensionType, UnrecognizedExtension) { // Write full IPC stream including schema, then unregister type, then read // and ensure that a plain instance of the storage type is created ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); - ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcOptions::Defaults(), + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), out_stream.get())); ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); @@ -270,7 +270,7 @@ TEST_F(TestExtensionType, UnrecognizedExtension) { io::BufferReader reader(complete_ipc_stream); std::shared_ptr batch_reader; - ASSERT_OK(ipc::RecordBatchStreamReader::Open(&reader, &batch_reader)); + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); std::shared_ptr read_batch; ASSERT_OK(batch_reader->ReadNext(&read_batch)); CompareBatch(*batch_no_ext, *read_batch); diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index eae31a52553..97b73548b8e 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -265,7 +265,7 @@ class GrpcStreamReader : public FlightStreamReader { private: friend class GrpcIpcMessageReader; - std::unique_ptr batch_reader_; + std::shared_ptr batch_reader_; std::shared_ptr last_app_metadata_; std::shared_ptr rpc_; }; @@ -327,8 +327,8 @@ Status GrpcStreamReader::Open(std::unique_ptr rpc, out->get()->rpc_ = std::move(rpc); std::unique_ptr message_reader( new GrpcIpcMessageReader(out->get(), out->get()->rpc_, std::move(stream))); - return ipc::RecordBatchStreamReader::Open(std::move(message_reader), - &(*out)->batch_reader_); + return (ipc::RecordBatchStreamReader::Open(std::move(message_reader)) + .Value(&(*out)->batch_reader_)); } std::shared_ptr GrpcStreamReader::schema() const { @@ -385,9 +385,6 @@ class GrpcStreamWriter : public FlightStreamWriter { } return Status::OK(); } - void set_memory_pool(MemoryPool* pool) override { - batch_writer_->set_memory_pool(pool); - } Status Close() override { return batch_writer_->Close(); } private: diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index c899076029d..59a32cba52d 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -102,8 +102,8 @@ class PerfDataStream : public FlightDataStream { } else { records_sent_ += batch_length_; } - return ipc::internal::GetRecordBatchPayload( - *batch, ipc_options_, default_memory_pool(), &payload->ipc_message); + return ipc::internal::GetRecordBatchPayload(*batch, ipc_options_, + &payload->ipc_message); } private: @@ -114,7 +114,7 @@ class PerfDataStream : public FlightDataStream { int64_t records_sent_; std::shared_ptr schema_; ipc::DictionaryMemo dictionary_memo_; - ipc::IpcOptions ipc_options_; + ipc::IpcWriteOptions ipc_options_; std::shared_ptr batch_; ArrayVector arrays_; }; diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 891b27e3311..014921c8518 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -150,7 +150,8 @@ class FlightMessageReaderImpl : public FlightMessageReader { Status Init() { message_reader_ = new FlightIpcMessageReader(reader_, &last_metadata_); return ipc::RecordBatchStreamReader::Open( - std::unique_ptr(message_reader_), &batch_reader_); + std::unique_ptr(message_reader_)) + .Value(&batch_reader_); } const FlightDescriptor& descriptor() const override { @@ -804,7 +805,9 @@ class RecordBatchStream::RecordBatchStreamImpl { RecordBatchStreamImpl(const std::shared_ptr& reader, MemoryPool* pool) - : pool_(pool), reader_(reader), ipc_options_(ipc::IpcOptions::Defaults()) {} + : reader_(reader), ipc_options_(ipc::IpcWriteOptions::Defaults()) { + ipc_options_.memory_pool = pool; + } std::shared_ptr schema() { return reader_->schema(); } @@ -828,7 +831,7 @@ class RecordBatchStream::RecordBatchStreamImpl { if (stage_ == Stage::DICTIONARY) { if (dictionary_index_ == static_cast(dictionaries_.size())) { stage_ = Stage::RECORD_BATCH; - return ipc::internal::GetRecordBatchPayload(*current_batch_, ipc_options_, pool_, + return ipc::internal::GetRecordBatchPayload(*current_batch_, ipc_options_, &payload->ipc_message); } else { return GetNextDictionary(payload); @@ -843,7 +846,7 @@ class RecordBatchStream::RecordBatchStreamImpl { payload->ipc_message.metadata = nullptr; return Status::OK(); } else { - return ipc::internal::GetRecordBatchPayload(*current_batch_, ipc_options_, pool_, + return ipc::internal::GetRecordBatchPayload(*current_batch_, ipc_options_, &payload->ipc_message); } } @@ -851,7 +854,7 @@ class RecordBatchStream::RecordBatchStreamImpl { private: Status GetNextDictionary(FlightPayload* payload) { const auto& it = dictionaries_[dictionary_index_++]; - return ipc::internal::GetDictionaryPayload(it.first, it.second, ipc_options_, pool_, + return ipc::internal::GetDictionaryPayload(it.first, it.second, ipc_options_, &payload->ipc_message); } @@ -864,10 +867,9 @@ class RecordBatchStream::RecordBatchStreamImpl { } Stage stage_ = Stage::NEW; - MemoryPool* pool_; std::shared_ptr reader_; ipc::DictionaryMemo dictionary_memo_; - ipc::IpcOptions ipc_options_; + ipc::IpcWriteOptions ipc_options_; std::shared_ptr current_batch_; std::vector>> dictionaries_; diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 080a25b8d90..43ce6ddbf12 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -121,8 +121,7 @@ std::string FlightDescriptor::ToString() const { Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, std::shared_ptr* out) const { io::BufferReader schema_reader(raw_schema_); - RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo, out)); - return Status::OK(); + return ipc::ReadSchema(&schema_reader, dictionary_memo).Value(out); } Status FlightDescriptor::SerializeToString(std::string* out) const { @@ -171,7 +170,7 @@ Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, return Status::OK(); } io::BufferReader schema_reader(data_.schema); - RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo, &schema_)); + RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo).Value(&schema_)); reconstructed_schema_ = true; *out = schema_; return Status::OK(); diff --git a/cpp/src/arrow/gpu/cuda_arrow_ipc.cc b/cpp/src/arrow/gpu/cuda_arrow_ipc.cc index 2f4fa9fb756..7ffcae05478 100644 --- a/cpp/src/arrow/gpu/cuda_arrow_ipc.cc +++ b/cpp/src/arrow/gpu/cuda_arrow_ipc.cc @@ -72,9 +72,8 @@ Result> ReadRecordBatch( // Zero-copy read on device memory ipc::DictionaryMemo unused_memo; - std::shared_ptr batch; - RETURN_NOT_OK(ipc::ReadRecordBatch(*message, schema, &unused_memo, &batch)); - return batch; + return ipc::ReadRecordBatch(*message, schema, &unused_memo, + ipc::IpcReadOptions::Defaults()); } Status ReadRecordBatch(const std::shared_ptr& schema, diff --git a/cpp/src/arrow/gpu/cuda_benchmark.cc b/cpp/src/arrow/gpu/cuda_benchmark.cc index 267d64a1776..2787d103cc7 100644 --- a/cpp/src/arrow/gpu/cuda_benchmark.cc +++ b/cpp/src/arrow/gpu/cuda_benchmark.cc @@ -37,12 +37,12 @@ static void CudaBufferWriterBenchmark(benchmark::State& state, const int64_t tot const int64_t chunksize, const int64_t buffer_size) { CudaDeviceManager* manager; - ABORT_NOT_OK(CudaDeviceManager::GetInstance(&manager)); + ABORT_NOT_OK(CudaDeviceManager::Instance().Value(&manager)); std::shared_ptr context; - ABORT_NOT_OK(manager->GetContext(kGpuNumber, &context)); + ABORT_NOT_OK(manager->GetContext(kGpuNumber).Value(&context)); std::shared_ptr device_buffer; - ABORT_NOT_OK(context->Allocate(total_bytes, &device_buffer)); + ABORT_NOT_OK(context->Allocate(total_bytes).Value(&device_buffer)); CudaBufferWriter writer(device_buffer); if (buffer_size > 0) { diff --git a/cpp/src/arrow/gpu/cuda_test.cc b/cpp/src/arrow/gpu/cuda_test.cc index 120109eebd3..e5f6a7d2813 100644 --- a/cpp/src/arrow/gpu/cuda_test.cc +++ b/cpp/src/arrow/gpu/cuda_test.cc @@ -575,7 +575,9 @@ TEST_F(TestCudaArrowIpc, BasicWriteRead) { std::shared_ptr cpu_batch; io::BufferReader cpu_reader(host_buffer); ipc::DictionaryMemo unused_memo; - ASSERT_OK(ipc::ReadRecordBatch(batch->schema(), &unused_memo, &cpu_reader, &cpu_batch)); + ASSERT_OK_AND_ASSIGN( + cpu_batch, ipc::ReadRecordBatch(batch->schema(), &unused_memo, + ipc::IpcReadOptions::Defaults(), &cpu_reader)); CompareBatch(*batch, *cpu_batch); } diff --git a/cpp/src/arrow/ipc/dictionary.cc b/cpp/src/arrow/ipc/dictionary.cc index a639f136446..8d9aa33ec9c 100644 --- a/cpp/src/arrow/ipc/dictionary.cc +++ b/cpp/src/arrow/ipc/dictionary.cc @@ -18,10 +18,9 @@ #include "arrow/ipc/dictionary.h" #include -#include #include -#include #include +#include #include "arrow/array.h" #include "arrow/record_batch.h" @@ -103,15 +102,15 @@ Status DictionaryMemo::AddField(int64_t id, const std::shared_ptr& field) } } -Status DictionaryMemo::GetId(const Field& field, int64_t* id) const { - auto it = field_to_id_.find(&field); +Status DictionaryMemo::GetId(const Field* field, int64_t* id) const { + auto it = field_to_id_.find(field); if (it != field_to_id_.end()) { // Field recorded, return the id *id = it->second; return Status::OK(); } else { return Status::KeyError("Field with memory address ", - reinterpret_cast(&field), " not found"); + reinterpret_cast(field), " not found"); } } diff --git a/cpp/src/arrow/ipc/dictionary.h b/cpp/src/arrow/ipc/dictionary.h index d8432d0c350..633f21fa1cd 100644 --- a/cpp/src/arrow/ipc/dictionary.h +++ b/cpp/src/arrow/ipc/dictionary.h @@ -61,7 +61,7 @@ class ARROW_EXPORT DictionaryMemo { /// \brief Return id for dictionary if it exists, otherwise return /// KeyError - Status GetId(const Field& type, int64_t* id) const; + Status GetId(const Field* type, int64_t* id) const; /// \brief Return true if dictionary for type is in this memo bool HasDictionary(const Field& type) const; diff --git a/cpp/src/arrow/ipc/feather.cc b/cpp/src/arrow/ipc/feather.cc index 1d7ec19f9c7..9a324a07757 100644 --- a/cpp/src/arrow/ipc/feather.cc +++ b/cpp/src/arrow/ipc/feather.cc @@ -32,6 +32,7 @@ #include "arrow/io/interfaces.h" #include "arrow/ipc/feather_internal.h" #include "arrow/ipc/util.h" // IWYU pragma: keep +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/table.h" // IWYU pragma: keep #include "arrow/type.h" diff --git a/cpp/src/arrow/ipc/file_to_stream.cc b/cpp/src/arrow/ipc/file_to_stream.cc index 91c885c73d8..292c193021c 100644 --- a/cpp/src/arrow/ipc/file_to_stream.cc +++ b/cpp/src/arrow/ipc/file_to_stream.cc @@ -34,14 +34,11 @@ namespace ipc { // Reads a file on the file system and prints to stdout the stream version of it. Status ConvertToStream(const char* path) { - std::shared_ptr reader; + io::StdoutStream sink; ARROW_ASSIGN_OR_RAISE(auto in_file, io::ReadableFile::Open(path)); - RETURN_NOT_OK(ipc::RecordBatchFileReader::Open(in_file.get(), &reader)); - - io::StdoutStream sink; - std::shared_ptr writer; - RETURN_NOT_OK(RecordBatchStreamWriter::Open(&sink, reader->schema(), &writer)); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(in_file.get())); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::NewStreamWriter(&sink, reader->schema())); for (int i = 0; i < reader->num_record_batches(); ++i) { std::shared_ptr chunk; RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk)); diff --git a/cpp/src/arrow/ipc/generate_fuzz_corpus.cc b/cpp/src/arrow/ipc/generate_fuzz_corpus.cc index 7526b1f77b5..3a876d052d5 100644 --- a/cpp/src/arrow/ipc/generate_fuzz_corpus.cc +++ b/cpp/src/arrow/ipc/generate_fuzz_corpus.cc @@ -94,12 +94,15 @@ Result>> Batches() { return batches; } -template Result> SerializeRecordBatch( - const std::shared_ptr& batch) { + const std::shared_ptr& batch, bool is_stream_format) { ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create(1024)); - ARROW_ASSIGN_OR_RAISE(auto writer, - RecordBatchWriterClass::Open(sink.get(), batch->schema())); + std::shared_ptr writer; + if (is_stream_format) { + ARROW_ASSIGN_OR_RAISE(writer, NewStreamWriter(sink.get(), batch->schema())); + } else { + ARROW_ASSIGN_OR_RAISE(writer, NewFileWriter(sink.get(), batch->schema())); + } RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); RETURN_NOT_OK(writer->Close()); return sink->Finish(); @@ -114,13 +117,11 @@ Status DoMain(bool is_stream_format, const std::string& out_dir) { return "batch-" + std::to_string(sample_num++); }; - auto serialize_func = is_stream_format ? SerializeRecordBatch - : SerializeRecordBatch; ARROW_ASSIGN_OR_RAISE(auto batches, Batches()); for (const auto& batch : batches) { RETURN_NOT_OK(batch->ValidateFull()); - ARROW_ASSIGN_OR_RAISE(auto buf, serialize_func(batch)); + ARROW_ASSIGN_OR_RAISE(auto buf, SerializeRecordBatch(batch, is_stream_format)); ARROW_ASSIGN_OR_RAISE(auto sample_fn, dir_fn.Join(sample_name())); std::cerr << sample_fn.ToString() << std::endl; ARROW_ASSIGN_OR_RAISE(auto file, io::FileOutputStream::Open(sample_fn.ToString())); diff --git a/cpp/src/arrow/ipc/json_integration.cc b/cpp/src/arrow/ipc/json_integration.cc index 1dcba204b72..17d3df1eb4d 100644 --- a/cpp/src/arrow/ipc/json_integration.cc +++ b/cpp/src/arrow/ipc/json_integration.cc @@ -18,20 +18,25 @@ #include "arrow/ipc/json_integration.h" #include +#include #include #include +#include -#include "arrow/array.h" #include "arrow/buffer.h" #include "arrow/io/file.h" #include "arrow/ipc/dictionary.h" #include "arrow/ipc/json_internal.h" -#include "arrow/memory_pool.h" #include "arrow/record_batch.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/logging.h" +#include +#include +#include + using std::size_t; namespace arrow { diff --git a/cpp/src/arrow/ipc/json_integration_test.cc b/cpp/src/arrow/ipc/json_integration_test.cc index 304814248d8..fde90fddd3a 100644 --- a/cpp/src/arrow/ipc/json_integration_test.cc +++ b/cpp/src/arrow/ipc/json_integration_test.cc @@ -71,9 +71,7 @@ static Status ConvertJsonToArrow(const std::string& json_path, << reader->schema()->ToString(/* show_metadata = */ true) << std::endl; } - std::shared_ptr writer; - RETURN_NOT_OK(RecordBatchFileWriter::Open(out_file.get(), reader->schema(), &writer)); - + ARROW_ASSIGN_OR_RAISE(auto writer, NewFileWriter(out_file.get(), reader->schema())); for (int i = 0; i < reader->num_record_batches(); ++i) { std::shared_ptr batch; RETURN_NOT_OK(reader->ReadRecordBatch(i, &batch)); @@ -89,7 +87,7 @@ static Status ConvertArrowToJson(const std::string& arrow_path, ARROW_ASSIGN_OR_RAISE(auto out_file, io::FileOutputStream::Open(json_path)); std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchFileReader::Open(in_file.get(), &reader)); + ARROW_ASSIGN_OR_RAISE(reader, RecordBatchFileReader::Open(in_file.get())); if (FLAGS_verbose) { std::cout << "Found schema:\n" << reader->schema()->ToString() << std::endl; @@ -124,7 +122,7 @@ static Status ValidateArrowVsJson(const std::string& arrow_path, ARROW_ASSIGN_OR_RAISE(auto arrow_file, io::ReadableFile::Open(arrow_path)); std::shared_ptr arrow_reader; - RETURN_NOT_OK(RecordBatchFileReader::Open(arrow_file.get(), &arrow_reader)); + ARROW_ASSIGN_OR_RAISE(arrow_reader, RecordBatchFileReader::Open(arrow_file.get())); auto json_schema = json_reader->schema(); auto arrow_schema = arrow_reader->schema(); diff --git a/cpp/src/arrow/ipc/json_internal.cc b/cpp/src/arrow/ipc/json_internal.cc index 9ec172896db..4e3b4508166 100644 --- a/cpp/src/arrow/ipc/json_internal.cc +++ b/cpp/src/arrow/ipc/json_internal.cc @@ -20,18 +20,18 @@ #include #include #include -#include #include #include -#include #include #include #include "arrow/array.h" #include "arrow/buffer.h" #include "arrow/builder.h" // IWYU pragma: keep +#include "arrow/extension_type.h" #include "arrow/ipc/dictionary.h" #include "arrow/record_batch.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -1432,7 +1432,7 @@ class ArrayReader { // Look up dictionary int64_t dictionary_id = -1; - RETURN_NOT_OK(dictionary_memo_->GetId(*field_, &dictionary_id)); + RETURN_NOT_OK(dictionary_memo_->GetId(field_.get(), &dictionary_id)); std::shared_ptr dictionary; RETURN_NOT_OK(dictionary_memo_->GetDictionary(dictionary_id, &dictionary)); diff --git a/cpp/src/arrow/ipc/json_internal.h b/cpp/src/arrow/ipc/json_internal.h index 1218f69e9a4..b792aec28a2 100644 --- a/cpp/src/arrow/ipc/json_internal.h +++ b/cpp/src/arrow/ipc/json_internal.h @@ -17,15 +17,18 @@ #pragma once +#include #include #include -#include "arrow/json/rapidjson_defs.h" -#include "rapidjson/document.h" // IWYU pragma: export -#include "rapidjson/encodings.h" // IWYU pragma: export -#include "rapidjson/error/en.h" // IWYU pragma: export -#include "rapidjson/stringbuffer.h" // IWYU pragma: export -#include "rapidjson/writer.h" // IWYU pragma: export +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep + +#include // IWYU pragma: export +#include // IWYU pragma: export +#include // IWYU pragma: export +#include // IWYU pragma: export +#include // IWYU pragma: export +#include // IWYU pragma: export #include "arrow/status.h" // IWYU pragma: export #include "arrow/type_fwd.h" // IWYU pragma: keep @@ -72,6 +75,13 @@ using RjObject = rj::Value::ConstObject; } namespace arrow { + +class Array; +class Field; +class MemoryPool; +class RecordBatch; +class Schema; + namespace ipc { class DictionaryMemo; diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index c301417c43e..1972ae41da5 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -21,18 +21,23 @@ #include #include -#include "arrow/array.h" #include "arrow/builder.h" -#include "arrow/ipc/json_internal.h" #include "arrow/ipc/json_simple.h" -#include "arrow/memory_pool.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" -#include "arrow/util/logging.h" #include "arrow/util/parsing.h" #include "arrow/util/string_view.h" +#include "arrow/json/rapidjson_defs.h" + +#include +#include +#include +#include + +namespace rj = arrow::rapidjson; + namespace arrow { namespace ipc { namespace internal { diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index 4fc250ae722..6192138f2a3 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -18,17 +18,17 @@ #include "arrow/ipc/message.h" #include +#include #include #include -#include #include #include -#include - #include "arrow/buffer.h" +#include "arrow/device.h" #include "arrow/io/interfaces.h" #include "arrow/ipc/metadata_internal.h" +#include "arrow/ipc/options.h" #include "arrow/ipc/util.h" #include "arrow/status.h" #include "arrow/util/logging.h" @@ -37,6 +37,10 @@ #include "generated/Message_generated.h" namespace arrow { + +class KeyValueMetadata; +class MemoryPool; + namespace ipc { class Message::MessageImpl { @@ -54,6 +58,12 @@ class Message::MessageImpl { return Status::Invalid("Old metadata version not supported"); } + if (message_->custom_metadata() != nullptr) { + // Deserialize from Flatbuffers if first time called + RETURN_NOT_OK( + internal::GetKeyValueMetadata(message_->custom_metadata(), &custom_metadata_)); + } + return Status::OK(); } @@ -86,11 +96,18 @@ class Message::MessageImpl { std::shared_ptr metadata() const { return metadata_; } + const std::shared_ptr& custom_metadata() const { + return custom_metadata_; + } + private: // The Flatbuffer metadata std::shared_ptr metadata_; const flatbuf::Message* message_; + // The recontructed custom_metadata field from the Message Flatbuffer + std::shared_ptr custom_metadata_; + // The message body, if any std::shared_ptr body_; }; @@ -120,6 +137,10 @@ MetadataVersion Message::metadata_version() const { return impl_->version(); } const void* Message::header() const { return impl_->header(); } +const std::shared_ptr& Message::custom_metadata() const { + return impl_->custom_metadata(); +} + bool Message::Equals(const Message& other) const { int64_t metadata_bytes = std::min(metadata()->size(), other.metadata()->size()); @@ -200,7 +221,7 @@ Status WritePadding(io::OutputStream* stream, int64_t nbytes) { return Status::OK(); } -Status Message::SerializeTo(io::OutputStream* stream, const IpcOptions& options, +Status Message::SerializeTo(io::OutputStream* stream, const IpcWriteOptions& options, int64_t* output_length) const { int32_t metadata_length = 0; RETURN_NOT_OK(WriteMessage(*metadata(), options, stream, &metadata_length)); @@ -368,7 +389,7 @@ Result> ReadMessage(io::InputStream* file, MemoryPool* return DoReadMessage(file, pool); } -Status WriteMessage(const Buffer& message, const IpcOptions& options, +Status WriteMessage(const Buffer& message, const IpcWriteOptions& options, io::OutputStream* file, int32_t* message_length) { const int32_t prefix_size = options.write_legacy_ipc_format ? 4 : 8; const int32_t flatbuffer_size = static_cast(message.size()); diff --git a/cpp/src/arrow/ipc/message.h b/cpp/src/arrow/ipc/message.h index 94b0ba2bbd3..81cd58bfe91 100644 --- a/cpp/src/arrow/ipc/message.h +++ b/cpp/src/arrow/ipc/message.h @@ -23,16 +23,27 @@ #include #include -#include "arrow/io/type_fwd.h" -#include "arrow/ipc/options.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" namespace arrow { + +namespace io { + +class FileInterface; +class InputStream; +class OutputStream; +class RandomAccessFile; + +} // namespace io + namespace ipc { +struct IpcWriteOptions; + enum class MetadataVersion : char { /// 0.1.0 V1, @@ -106,6 +117,10 @@ class ARROW_EXPORT Message { /// \return buffer std::shared_ptr metadata() const; + /// \brief Custom metadata serialized in metadata Flatbuffer. Returns nullptr + /// when none set + const std::shared_ptr& custom_metadata() const; + /// \brief the Message body, if any /// /// \return buffer is null if no body @@ -129,7 +144,7 @@ class ARROW_EXPORT Message { /// \param[in] options IPC writing options including alignment /// \param[out] output_length the number of bytes written /// \return Status - Status SerializeTo(io::OutputStream* file, const IpcOptions& options, + Status SerializeTo(io::OutputStream* file, const IpcWriteOptions& options, int64_t* output_length) const; /// \brief Return true if the Message metadata passes Flatbuffer validation @@ -244,7 +259,7 @@ Result> ReadMessage(io::InputStream* stream, /// \param[out] message_length the total size of the payload written including /// padding /// \return Status -Status WriteMessage(const Buffer& message, const IpcOptions& options, +Status WriteMessage(const Buffer& message, const IpcWriteOptions& options, io::OutputStream* file, int32_t* message_length); } // namespace ipc diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc index b9c960b90aa..ac30467a291 100644 --- a/cpp/src/arrow/ipc/metadata_internal.cc +++ b/cpp/src/arrow/ipc/metadata_internal.cc @@ -25,25 +25,23 @@ #include -#include "arrow/array.h" #include "arrow/extension_type.h" #include "arrow/io/interfaces.h" +#include "arrow/ipc/dictionary.h" #include "arrow/ipc/message.h" -#include "arrow/ipc/util.h" #include "arrow/sparse_tensor.h" #include "arrow/status.h" -#include "arrow/tensor.h" #include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/key_value_metadata.h" -#include "arrow/util/logging.h" -#include "arrow/util/ubsan.h" #include "arrow/visitor_inline.h" -#include "generated/File_generated.h" // IWYU pragma: keep +#include "generated/File_generated.h" #include "generated/Message_generated.h" -#include "generated/SparseTensor_generated.h" // IWYU pragma: keep -#include "generated/Tensor_generated.h" // IWYU pragma: keep +#include "generated/Schema_generated.h" +#include "generated/SparseTensor_generated.h" +#include "generated/Tensor_generated.h" namespace arrow { @@ -56,12 +54,10 @@ namespace internal { using FBB = flatbuffers::FlatBufferBuilder; using DictionaryOffset = flatbuffers::Offset; using FieldOffset = flatbuffers::Offset; -using KeyValueOffset = flatbuffers::Offset; using RecordBatchOffset = flatbuffers::Offset; using SparseTensorOffset = flatbuffers::Offset; using Offset = flatbuffers::Offset; using FBString = flatbuffers::Offset; -using KVVector = flatbuffers::Vector; MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version) { switch (version) { @@ -461,22 +457,6 @@ void AppendKeyValueMetadata(FBB& fbb, const KeyValueMetadata& metadata, } } -Status KeyValueMetadataFromFlatbuffer(const KVVector* fb_metadata, - std::shared_ptr* out) { - auto metadata = std::make_shared(); - - metadata->reserve(fb_metadata->size()); - for (const auto& pair : *fb_metadata) { - CHECK_FLATBUFFERS_NOT_NULL(pair->key(), "custom_metadata.key"); - CHECK_FLATBUFFERS_NOT_NULL(pair->value(), "custom_metadata.value"); - metadata->Append(pair->key()->str(), pair->value()->str()); - } - - *out = metadata; - - return Status::OK(); -} - class FieldToFlatbufferVisitor { public: FieldToFlatbufferVisitor(FBB& fbb, DictionaryMemo* dictionary_memo) @@ -743,21 +723,12 @@ Status FieldToFlatbuffer(FBB& fbb, const std::shared_ptr& field, return field_visitor.GetResult(field, offset); } -Status GetFieldMetadata(const flatbuf::Field* field, - std::shared_ptr* metadata) { - auto fb_metadata = field->custom_metadata(); - if (fb_metadata != nullptr) { - RETURN_NOT_OK(KeyValueMetadataFromFlatbuffer(fb_metadata, metadata)); - } - return Status::OK(); -} - Status FieldFromFlatbuffer(const flatbuf::Field* field, DictionaryMemo* dictionary_memo, std::shared_ptr* out) { std::shared_ptr type; - std::shared_ptr metadata; - RETURN_NOT_OK(GetFieldMetadata(field, &metadata)); + std::shared_ptr metadata; + RETURN_NOT_OK(internal::GetKeyValueMetadata(field->custom_metadata(), &metadata)); // Reconstruct the data type auto children = field->children(); @@ -802,6 +773,18 @@ flatbuf::Endianness endianness() { return bint.c[0] == 1 ? flatbuf::Endianness::Big : flatbuf::Endianness::Little; } +flatbuffers::Offset SerializeCustomMetadata( + FBB& fbb, const std::shared_ptr& metadata) { + std::vector key_values; + if (metadata != nullptr) { + AppendKeyValueMetadata(fbb, *metadata, &key_values); + return fbb.CreateVector(key_values); + } else { + // null + return 0; + } +} + Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, DictionaryMemo* dictionary_memo, flatbuffers::Offset* out) { /// Fields @@ -813,26 +796,18 @@ Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, DictionaryMemo* dictio } auto fb_offsets = fbb.CreateVector(field_offsets); - - /// Custom metadata - auto metadata = schema.metadata(); - - flatbuffers::Offset fb_custom_metadata; - std::vector key_values; - if (metadata != nullptr) { - AppendKeyValueMetadata(fbb, *metadata, &key_values); - fb_custom_metadata = fbb.CreateVector(key_values); - } - *out = flatbuf::CreateSchema(fbb, endianness(), fb_offsets, fb_custom_metadata); + *out = flatbuf::CreateSchema(fbb, endianness(), fb_offsets, + SerializeCustomMetadata(fbb, schema.metadata())); return Status::OK(); } -Result> WriteFBMessage(FBB& fbb, - flatbuf::MessageHeader header_type, - flatbuffers::Offset header, - int64_t body_length) { - auto message = flatbuf::CreateMessage(fbb, kCurrentMetadataVersion, header_type, header, - body_length); +Result> WriteFBMessage( + FBB& fbb, flatbuf::MessageHeader header_type, flatbuffers::Offset header, + int64_t body_length, + const std::shared_ptr& custom_metadata = nullptr) { + auto message = + flatbuf::CreateMessage(fbb, kCurrentMetadataVersion, header_type, header, + body_length, SerializeCustomMetadata(fbb, custom_metadata)); fbb.Finish(message); return WriteFlatbufferBuilder(fbb); } @@ -1027,24 +1002,46 @@ Status MakeSparseTensor(FBB& fbb, const SparseTensor& sparse_tensor, int64_t bod } // namespace +Status GetKeyValueMetadata(const KVVector* fb_metadata, + std::shared_ptr* out) { + if (fb_metadata == nullptr) { + *out = nullptr; + return Status::OK(); + } + + auto metadata = std::make_shared(); + + metadata->reserve(fb_metadata->size()); + for (const auto& pair : *fb_metadata) { + CHECK_FLATBUFFERS_NOT_NULL(pair->key(), "custom_metadata.key"); + CHECK_FLATBUFFERS_NOT_NULL(pair->value(), "custom_metadata.value"); + metadata->Append(pair->key()->str(), pair->value()->str()); + } + + *out = std::move(metadata); + return Status::OK(); +} + Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo, std::shared_ptr* out) { FBB fbb; flatbuffers::Offset fb_schema; RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); - return WriteFBMessage(fbb, flatbuf::MessageHeader::Schema, fb_schema.Union(), 0) + return WriteFBMessage(fbb, flatbuf::MessageHeader::Schema, fb_schema.Union(), + /*body_length=*/0) .Value(out); } -Status WriteRecordBatchMessage(int64_t length, int64_t body_length, - const std::vector& nodes, - const std::vector& buffers, - std::shared_ptr* out) { +Status WriteRecordBatchMessage( + int64_t length, int64_t body_length, + const std::shared_ptr& custom_metadata, + const std::vector& nodes, const std::vector& buffers, + std::shared_ptr* out) { FBB fbb; RecordBatchOffset record_batch; RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch)); return WriteFBMessage(fbb, flatbuf::MessageHeader::RecordBatch, record_batch.Union(), - body_length) + body_length, custom_metadata) .Value(out); } @@ -1093,16 +1090,17 @@ Result> WriteSparseTensorMessage( fb_sparse_tensor.Union(), body_length); } -Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length, - const std::vector& nodes, - const std::vector& buffers, - std::shared_ptr* out) { +Status WriteDictionaryMessage( + int64_t id, int64_t length, int64_t body_length, + const std::shared_ptr& custom_metadata, + const std::vector& nodes, const std::vector& buffers, + std::shared_ptr* out) { FBB fbb; RecordBatchOffset record_batch; RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch)); auto dictionary_batch = flatbuf::CreateDictionaryBatch(fbb, id, record_batch).Union(); return WriteFBMessage(fbb, flatbuf::MessageHeader::DictionaryBatch, dictionary_batch, - body_length) + body_length, custom_metadata) .Value(out); } @@ -1169,15 +1167,9 @@ Status GetSchema(const void* opaque_schema, DictionaryMemo* dictionary_memo, RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i])); } - auto fb_metadata = schema->custom_metadata(); - std::shared_ptr metadata; - - if (fb_metadata != nullptr) { - RETURN_NOT_OK(KeyValueMetadataFromFlatbuffer(fb_metadata, &metadata)); - } - + std::shared_ptr metadata; + RETURN_NOT_OK(internal::GetKeyValueMetadata(schema->custom_metadata(), &metadata)); *out = ::arrow::schema(std::move(fields), metadata); - return Status::OK(); } diff --git a/cpp/src/arrow/ipc/metadata_internal.h b/cpp/src/arrow/ipc/metadata_internal.h index 7ed7f498aeb..327a86118a4 100644 --- a/cpp/src/arrow/ipc/metadata_internal.h +++ b/cpp/src/arrow/ipc/metadata_internal.h @@ -28,20 +28,26 @@ #include #include "arrow/buffer.h" -#include "arrow/ipc/dictionary.h" // IYWU pragma: keep #include "arrow/ipc/message.h" +#include "arrow/result.h" #include "arrow/sparse_tensor.h" #include "arrow/status.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" +#include "arrow/util/visibility.h" #include "generated/Message_generated.h" #include "generated/Schema_generated.h" +#include "generated/SparseTensor_generated.h" // IWYU pragma: keep namespace arrow { namespace flatbuf = org::apache::arrow::flatbuf; +class DataType; +class KeyValueMetadata; +class Schema; + namespace io { class OutputStream; @@ -54,6 +60,9 @@ class DictionaryMemo; namespace internal { +using KeyValueOffset = flatbuffers::Offset; +using KVVector = flatbuffers::Vector; + // This 0xFFFFFFFF value is the first 4 bytes of a valid IPC message constexpr int32_t kIpcContinuationToken = -1; @@ -134,6 +143,9 @@ Status GetSparseTensorMetadata(const Buffer& metadata, std::shared_ptr std::vector* dim_names, int64_t* length, SparseTensorFormat::type* sparse_tensor_format_id); +Status GetKeyValueMetadata(const KVVector* fb_metadata, + std::shared_ptr* out); + static inline Status VerifyMessage(const uint8_t* data, int64_t size, const flatbuf::Message** out) { flatbuffers::Verifier verifier(data, size, /*max_depth=*/128); @@ -154,10 +166,13 @@ static inline Status VerifyMessage(const uint8_t* data, int64_t size, Status WriteSchemaMessage(const Schema& schema, DictionaryMemo* dictionary_memo, std::shared_ptr* out); -Status WriteRecordBatchMessage(const int64_t length, const int64_t body_length, - const std::vector& nodes, - const std::vector& buffers, - std::shared_ptr* out); +// This function is used in a unit test +ARROW_EXPORT +Status WriteRecordBatchMessage( + const int64_t length, const int64_t body_length, + const std::shared_ptr& custom_metadata, + const std::vector& nodes, const std::vector& buffers, + std::shared_ptr* out); Result> WriteTensorMessage(const Tensor& tensor, const int64_t buffer_start_offset); @@ -170,11 +185,11 @@ Status WriteFileFooter(const Schema& schema, const std::vector& dicti const std::vector& record_batches, io::OutputStream* out); -Status WriteDictionaryMessage(const int64_t id, const int64_t length, - const int64_t body_length, - const std::vector& nodes, - const std::vector& buffers, - std::shared_ptr* out); +Status WriteDictionaryMessage( + const int64_t id, const int64_t length, const int64_t body_length, + const std::shared_ptr& custom_metadata, + const std::vector& nodes, const std::vector& buffers, + std::shared_ptr* out); static inline Result> WriteFlatbufferBuilder( flatbuffers::FlatBufferBuilder& fbb) { diff --git a/cpp/src/arrow/ipc/options.cc b/cpp/src/arrow/ipc/options.cc index a5714f38aec..933c7db9ad3 100644 --- a/cpp/src/arrow/ipc/options.cc +++ b/cpp/src/arrow/ipc/options.cc @@ -20,7 +20,9 @@ namespace arrow { namespace ipc { -IpcOptions IpcOptions::Defaults() { return IpcOptions(); } +IpcWriteOptions IpcWriteOptions::Defaults() { return IpcWriteOptions(); } + +IpcReadOptions IpcReadOptions::Defaults() { return IpcReadOptions(); } } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/options.h b/cpp/src/arrow/ipc/options.h index 3570c06eb43..3281ba50378 100644 --- a/cpp/src/arrow/ipc/options.h +++ b/cpp/src/arrow/ipc/options.h @@ -18,10 +18,17 @@ #pragma once #include +#include +#include "arrow/type_fwd.h" +#include "arrow/util/compression.h" +#include "arrow/util/optional.h" #include "arrow/util/visibility.h" namespace arrow { + +class MemoryPool; + namespace ipc { // ARROW-109: We set this number arbitrarily to help catch user mistakes. For @@ -29,7 +36,8 @@ namespace ipc { // maximum allowed recursion depth constexpr int kMaxNestingDepth = 64; -struct ARROW_EXPORT IpcOptions { +/// \brief Options for writing Arrow IPC messages +struct ARROW_EXPORT IpcWriteOptions { // If true, allow field lengths that don't fit in a signed 32-bit int. // Some implementations may not be able to parse such streams. bool allow_64bit = false; @@ -44,7 +52,34 @@ struct ARROW_EXPORT IpcOptions { /// consisting of a 4-byte prefix instead of 8 byte bool write_legacy_ipc_format = false; - static IpcOptions Defaults(); + /// \brief The memory pool to use for allocations made during IPC writing + MemoryPool* memory_pool = default_memory_pool(); + + /// \brief EXPERIMENTAL: Codec to use for compressing and decompressing + /// record batch body buffers. This is not part of the Arrow IPC protocol and + /// only for internal use (e.g. Feather files) + Compression::type compression = Compression::UNCOMPRESSED; + int compression_level = Compression::kUseDefaultCompressionLevel; + + static IpcWriteOptions Defaults(); +}; + +#ifndef ARROW_NO_DEPRECATED_API +using IpcOptions = IpcWriteOptions; +#endif + +struct ARROW_EXPORT IpcReadOptions { + // The maximum permitted schema nesting depth. + int max_recursion_depth = kMaxNestingDepth; + + /// \brief The memory pool to use for allocations made during IPC writing + MemoryPool* memory_pool = default_memory_pool(); + + /// \brief EXPERIMENTAL: Top-level schema fields to include when + /// deserializing RecordBatch. If null, return all deserialized fields + util::optional> included_fields; + + static IpcReadOptions Defaults(); }; } // namespace ipc diff --git a/cpp/src/arrow/ipc/read_write_benchmark.cc b/cpp/src/arrow/ipc/read_write_benchmark.cc index d2b1d99d592..cad10cb0d58 100644 --- a/cpp/src/arrow/ipc/read_write_benchmark.cc +++ b/cpp/src/arrow/ipc/read_write_benchmark.cc @@ -50,7 +50,7 @@ std::shared_ptr MakeRecordBatch(int64_t total_size, int64_t num_fie static void WriteRecordBatch(benchmark::State& state) { // NOLINT non-const reference // 1MB constexpr int64_t kTotalSize = 1 << 20; - auto options = ipc::IpcOptions::Defaults(); + auto options = ipc::IpcWriteOptions::Defaults(); std::shared_ptr buffer; ABORT_NOT_OK(AllocateResizableBuffer(kTotalSize & 2, &buffer)); @@ -61,7 +61,7 @@ static void WriteRecordBatch(benchmark::State& state) { // NOLINT non-const ref int32_t metadata_length; int64_t body_length; if (!ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length, &body_length, - options, default_memory_pool()) + options) .ok()) { state.SkipWithError("Failed to write!"); } @@ -72,7 +72,7 @@ static void WriteRecordBatch(benchmark::State& state) { // NOLINT non-const ref static void ReadRecordBatch(benchmark::State& state) { // NOLINT non-const reference // 1MB constexpr int64_t kTotalSize = 1 << 20; - auto options = ipc::IpcOptions::Defaults(); + auto options = ipc::IpcWriteOptions::Defaults(); std::shared_ptr buffer; ABORT_NOT_OK(AllocateResizableBuffer(kTotalSize & 2, &buffer)); @@ -83,17 +83,16 @@ static void ReadRecordBatch(benchmark::State& state) { // NOLINT non-const refe int32_t metadata_length; int64_t body_length; if (!ipc::WriteRecordBatch(*record_batch, 0, &stream, &metadata_length, &body_length, - options, default_memory_pool()) + options) .ok()) { state.SkipWithError("Failed to write!"); } ipc::DictionaryMemo empty_memo; while (state.KeepRunning()) { - std::shared_ptr result; io::BufferReader reader(buffer); - - if (!ipc::ReadRecordBatch(record_batch->schema(), &empty_memo, &reader, &result) + if (!ipc::ReadRecordBatch(record_batch->schema(), &empty_memo, + ipc::IpcReadOptions::Defaults(), &reader) .ok()) { state.SkipWithError("Failed to read!"); } diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 4d9fd04ca3e..1adf9229a8d 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -103,7 +104,7 @@ TEST(TestMessage, SerializeTo) { ASSERT_OK(Message::Open(metadata, std::make_shared(body), &message)); auto CheckWithAlignment = [&](int32_t alignment) { - IpcOptions options; + IpcWriteOptions options; options.alignment = alignment; const int32_t prefix_size = 8; int64_t output_length = 0; @@ -118,6 +119,26 @@ TEST(TestMessage, SerializeTo) { CheckWithAlignment(64); } +TEST(TestMessage, SerializeCustomMetadata) { + std::vector> cases = { + nullptr, key_value_metadata({}, {}), + key_value_metadata({"foo", "bar"}, {"fizz", "buzz"})}; + for (auto metadata : cases) { + std::shared_ptr serialized; + std::unique_ptr message; + ASSERT_OK(internal::WriteRecordBatchMessage(/*length=*/0, /*body_length=*/0, metadata, + /*nodes=*/{}, + /*buffers=*/{}, &serialized)); + ASSERT_OK(Message::Open(serialized, /*body=*/nullptr, &message)); + + if (metadata) { + ASSERT_TRUE(message->custom_metadata()->Equals(*metadata)); + } else { + ASSERT_EQ(nullptr, message->custom_metadata()); + } + } +} + void BuffersOverlapEquals(const Buffer& left, const Buffer& right) { ASSERT_GT(left.size(), 0); ASSERT_GT(right.size(), 0); @@ -128,12 +149,11 @@ TEST(TestMessage, LegacyIpcBackwardsCompatibility) { std::shared_ptr batch; ASSERT_OK(MakeIntBatchSized(36, &batch)); - auto RoundtripWithOptions = [&](const IpcOptions& arg_options, + auto RoundtripWithOptions = [&](const IpcWriteOptions& arg_options, std::shared_ptr* out_serialized, std::unique_ptr* out) { internal::IpcPayload payload; - ASSERT_OK(internal::GetRecordBatchPayload(*batch, arg_options, default_memory_pool(), - &payload)); + ASSERT_OK(internal::GetRecordBatchPayload(*batch, arg_options, &payload)); ASSERT_OK_AND_ASSIGN(auto stream, io::BufferOutputStream::Create(1 << 20)); @@ -149,7 +169,7 @@ TEST(TestMessage, LegacyIpcBackwardsCompatibility) { std::shared_ptr serialized, legacy_serialized; std::unique_ptr message, legacy_message; - IpcOptions options; + IpcWriteOptions options; RoundtripWithOptions(options, &serialized, &message); // First 4 bytes 0xFFFFFFFF Continuation marker @@ -183,10 +203,9 @@ class TestSchemaMetadata : public ::testing::Test { DictionaryMemo in_memo, out_memo; ASSERT_OK(SerializeSchema(schema, &out_memo, default_memory_pool(), &buffer)); - std::shared_ptr result; io::BufferReader reader(buffer); - ASSERT_OK(ReadSchema(&reader, &in_memo, &result)); - AssertSchemaEqual(schema, *result); + ASSERT_OK_AND_ASSIGN(auto actual_schema, ReadSchema(&reader, &in_memo)); + AssertSchemaEqual(schema, *actual_schema); } }; @@ -262,29 +281,30 @@ static int g_file_number = 0; class IpcTestFixture : public io::MemoryMapFixture { public: - void SetUp() { - pool_ = default_memory_pool(); - options_ = IpcOptions::Defaults(); - } + void SetUp() { options_ = IpcWriteOptions::Defaults(); } void DoSchemaRoundTrip(const Schema& schema, DictionaryMemo* out_memo, std::shared_ptr* result) { std::shared_ptr serialized_schema; - ASSERT_OK(SerializeSchema(schema, out_memo, pool_, &serialized_schema)); + ASSERT_OK( + SerializeSchema(schema, out_memo, options_.memory_pool, &serialized_schema)); DictionaryMemo in_memo; io::BufferReader buf_reader(serialized_schema); - ASSERT_OK(ReadSchema(&buf_reader, &in_memo, result)); + ASSERT_OK_AND_ASSIGN(*result, ReadSchema(&buf_reader, &in_memo)); ASSERT_EQ(out_memo->num_fields(), in_memo.num_fields()); } - Status DoStandardRoundTrip(const RecordBatch& batch, DictionaryMemo* dictionary_memo, - std::shared_ptr* batch_result) { + Status DoStandardRoundTrip(const RecordBatch& batch, const IpcWriteOptions& options, + DictionaryMemo* dictionary_memo, + std::shared_ptr* result) { std::shared_ptr serialized_batch; - RETURN_NOT_OK(SerializeRecordBatch(batch, pool_, &serialized_batch)); + RETURN_NOT_OK(SerializeRecordBatch(batch, options, &serialized_batch)); io::BufferReader buf_reader(serialized_batch); - return ReadRecordBatch(batch.schema(), dictionary_memo, &buf_reader, batch_result); + return ReadRecordBatch(batch.schema(), dictionary_memo, IpcReadOptions::Defaults(), + &buf_reader) + .Value(result); } Status DoLargeRoundTrip(const RecordBatch& batch, bool zero_data, @@ -297,16 +317,15 @@ class IpcTestFixture : public io::MemoryMapFixture { auto options = options_; options.allow_64bit = true; - auto res = RecordBatchFileWriter::Open(mmap_.get(), batch.schema(), options); - RETURN_NOT_OK(res.status()); - std::shared_ptr file_writer = *res; + ARROW_ASSIGN_OR_RAISE(auto file_writer, + NewFileWriter(mmap_.get(), batch.schema(), options)); RETURN_NOT_OK(file_writer->WriteRecordBatch(batch)); RETURN_NOT_OK(file_writer->Close()); ARROW_ASSIGN_OR_RAISE(int64_t offset, mmap_->Tell()); std::shared_ptr file_reader; - RETURN_NOT_OK(RecordBatchFileReader::Open(mmap_.get(), offset, &file_reader)); + ARROW_ASSIGN_OR_RAISE(file_reader, RecordBatchFileReader::Open(mmap_.get(), offset)); return file_reader->ReadRecordBatch(0, result); } @@ -322,7 +341,9 @@ class IpcTestFixture : public io::MemoryMapFixture { CompareBatchColumnsDetailed(result, expected); } - void CheckRoundtrip(const RecordBatch& batch, int64_t buffer_size) { + void CheckRoundtrip(const RecordBatch& batch, + IpcWriteOptions options = IpcWriteOptions::Defaults(), + int64_t buffer_size = 1 << 20) { std::stringstream ss; ss << "test-write-row-batch-" << g_file_number++; ASSERT_OK_AND_ASSIGN(mmap_, @@ -337,26 +358,26 @@ class IpcTestFixture : public io::MemoryMapFixture { ASSERT_OK(CollectDictionaries(batch, &dictionary_memo)); std::shared_ptr result; - ASSERT_OK(DoStandardRoundTrip(batch, &dictionary_memo, &result)); + ASSERT_OK(DoStandardRoundTrip(batch, options, &dictionary_memo, &result)); CheckReadResult(*result, batch); ASSERT_OK(DoLargeRoundTrip(batch, /*zero_data=*/true, &result)); CheckReadResult(*result, batch); } - - void CheckRoundtrip(const std::shared_ptr& array, int64_t buffer_size) { + void CheckRoundtrip(const std::shared_ptr& array, + IpcWriteOptions options = IpcWriteOptions::Defaults(), + int64_t buffer_size = 1 << 20) { auto f0 = arrow::field("f0", array->type()); std::vector> fields = {f0}; auto schema = std::make_shared(fields); auto batch = RecordBatch::Make(schema, 0, {array}); - CheckRoundtrip(*batch, buffer_size); + CheckRoundtrip(*batch, options, buffer_size); } protected: std::shared_ptr mmap_; - MemoryPool* pool_; - IpcOptions options_; + IpcWriteOptions options_; }; class TestWriteRecordBatch : public ::testing::Test, public IpcTestFixture { @@ -376,7 +397,7 @@ TEST_P(TestIpcRoundTrip, RoundTrip) { std::shared_ptr batch; ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue - CheckRoundtrip(*batch, 1 << 20); + CheckRoundtrip(*batch); } TEST_F(TestIpcRoundTrip, MetadataVersion) { @@ -392,7 +413,7 @@ TEST_F(TestIpcRoundTrip, MetadataVersion) { const int64_t buffer_offset = 0; ASSERT_OK(WriteRecordBatch(*batch, buffer_offset, mmap_.get(), &metadata_length, - &body_length, options_, pool_)); + &body_length, options_)); std::unique_ptr message; ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); @@ -422,7 +443,7 @@ TEST_P(TestIpcRoundTrip, SliceRoundTrip) { } auto sliced_batch = batch->Slice(2, 10); - CheckRoundtrip(*sliced_batch, 1 << 20); + CheckRoundtrip(*sliced_batch); } TEST_P(TestIpcRoundTrip, ZeroLengthArrays) { @@ -436,11 +457,11 @@ TEST_P(TestIpcRoundTrip, ZeroLengthArrays) { zero_length_batch = batch->Slice(0, 0); } - CheckRoundtrip(*zero_length_batch, 1 << 20); + CheckRoundtrip(*zero_length_batch); // ARROW-544: check binary array std::shared_ptr value_offsets; - ASSERT_OK(AllocateBuffer(pool_, sizeof(int32_t), &value_offsets)); + ASSERT_OK(AllocateBuffer(options_.memory_pool, sizeof(int32_t), &value_offsets)); *reinterpret_cast(value_offsets->mutable_data()) = 0; std::shared_ptr bin_array = std::make_shared( @@ -450,8 +471,44 @@ TEST_P(TestIpcRoundTrip, ZeroLengthArrays) { // null value_offsets std::shared_ptr bin_array2 = std::make_shared(0, nullptr, nullptr); - CheckRoundtrip(bin_array, 1 << 20); - CheckRoundtrip(bin_array2, 1 << 20); + CheckRoundtrip(bin_array); + CheckRoundtrip(bin_array2); +} + +TEST_F(TestWriteRecordBatch, WriteWithCompression) { + random::RandomArrayGenerator rg(/*seed=*/0); + + // Generate both regular and dictionary encoded because the dictionary batch + // gets compressed also + + int64_t length = 500; + + int dict_size = 50; + std::shared_ptr dict = rg.String(dict_size, /*min_length=*/5, /*max_length=*/5, + /*null_probability=*/0); + std::shared_ptr indices = rg.Int32(length, /*min=*/0, /*max=*/dict_size - 1, + /*null_probability=*/0.1); + + auto dict_type = dictionary(int32(), utf8()); + auto dict_field = field("f1", dict_type); + std::shared_ptr dict_array; + ASSERT_OK(DictionaryArray::FromArrays(dict_type, indices, dict, &dict_array)); + + auto schema = ::arrow::schema({field("f0", utf8()), dict_field}); + auto batch = + RecordBatch::Make(schema, length, {rg.String(500, 0, 10, 0.1), dict_array}); + + std::vector codecs = {Compression::GZIP, Compression::LZ4, + Compression::ZSTD, Compression::SNAPPY, + Compression::BROTLI}; + for (auto codec : codecs) { + if (!util::Codec::IsAvailable(codec)) { + return; + } + IpcWriteOptions options = IpcWriteOptions::Defaults(); + options.compression = codec; + CheckRoundtrip(*batch, options); + } } TEST_F(TestWriteRecordBatch, SliceTruncatesBinaryOffsets) { @@ -470,7 +527,8 @@ TEST_F(TestWriteRecordBatch, SliceTruncatesBinaryOffsets) { mmap_, io::MemoryMapFixture::InitMemoryMap(/*buffer_size=*/1 << 20, ss.str())); DictionaryMemo dictionary_memo; std::shared_ptr result; - ASSERT_OK(DoStandardRoundTrip(*sliced_batch, &dictionary_memo, &result)); + ASSERT_OK(DoStandardRoundTrip(*sliced_batch, IpcWriteOptions::Defaults(), + &dictionary_memo, &result)); ASSERT_EQ(6 * sizeof(int32_t), result->column(0)->data()->buffers[1]->size()); } @@ -489,7 +547,7 @@ TEST_F(TestWriteRecordBatch, SliceTruncatesBuffers) { ASSERT_TRUE(sliced_size < full_size) << sliced_size << " " << full_size; // make sure we can write and read it - this->CheckRoundtrip(*sliced_batch, 1 << 20); + this->CheckRoundtrip(*sliced_batch); }; std::shared_ptr a0, a1; @@ -557,7 +615,8 @@ TEST_F(TestWriteRecordBatch, RoundtripPreservesBufferSizes) { mmap_, io::MemoryMapFixture::InitMemoryMap(/*buffer_size=*/1 << 20, ss.str())); DictionaryMemo dictionary_memo; std::shared_ptr result; - ASSERT_OK(DoStandardRoundTrip(*batch, &dictionary_memo, &result)); + ASSERT_OK(DoStandardRoundTrip(*batch, IpcWriteOptions::Defaults(), &dictionary_memo, + &result)); // Make sure that the validity bitmap is size 2 as expected ASSERT_EQ(2, arr->data()->buffers[0]->size()); @@ -568,14 +627,14 @@ TEST_F(TestWriteRecordBatch, RoundtripPreservesBufferSizes) { } } -void TestGetRecordBatchSize(const IpcOptions& options, +void TestGetRecordBatchSize(const IpcWriteOptions& options, std::shared_ptr batch) { io::MockOutputStream mock; int32_t mock_metadata_length = -1; int64_t mock_body_length = -1; int64_t size = -1; ASSERT_OK(WriteRecordBatch(*batch, 0, &mock, &mock_metadata_length, &mock_body_length, - options, default_memory_pool())); + options)); ASSERT_OK(GetRecordBatchSize(*batch, &size)); ASSERT_EQ(mock.GetExtentBytesWritten(), size); } @@ -630,12 +689,12 @@ class RecursionLimits : public ::testing::Test, public io::MemoryMapFixture { ARROW_ASSIGN_OR_RAISE(mmap_, io::MemoryMapFixture::InitMemoryMap(memory_map_size, ss.str())); - auto options = IpcOptions::Defaults(); + auto options = IpcWriteOptions::Defaults(); if (override_level) { options.max_recursion_depth = recursion_level + 1; } return WriteRecordBatch(**batch, 0, mmap_.get(), metadata_length, body_length, - options, pool_); + options); } protected: @@ -669,9 +728,8 @@ TEST_F(RecursionLimits, ReadLimit) { io::BufferReader reader(message->body()); DictionaryMemo empty_memo; - std::shared_ptr result; ASSERT_RAISES(Invalid, ReadRecordBatch(*message->metadata(), schema, &empty_memo, - &reader, &result)); + IpcReadOptions::Defaults(), &reader)); } // Test fails with a structured exception on Windows + Debug @@ -690,12 +748,12 @@ TEST_F(RecursionLimits, StressLimit) { DictionaryMemo empty_memo; - auto options = IpcOptions::Defaults(); + auto options = IpcReadOptions::Defaults(); options.max_recursion_depth = recursion_depth + 1; io::BufferReader reader(message->body()); std::shared_ptr result; - ASSERT_OK(ReadRecordBatch(*message->metadata(), schema, &empty_memo, options, &reader, - &result)); + ASSERT_OK_AND_ASSIGN(result, ReadRecordBatch(*message->metadata(), schema, + &empty_memo, options, &reader)); *it_works = result->Equals(*batch); }; @@ -712,13 +770,12 @@ TEST_F(RecursionLimits, StressLimit) { #endif // !defined(_WIN32) || defined(NDEBUG) struct FileWriterHelper { - Status Init(const std::shared_ptr& schema, const IpcOptions& options) { + Status Init(const std::shared_ptr& schema, const IpcWriteOptions& options) { num_batches_written_ = 0; RETURN_NOT_OK(AllocateResizableBuffer(0, &buffer_)); sink_.reset(new io::BufferOutputStream(buffer_)); - ARROW_ASSIGN_OR_RAISE(writer_, - RecordBatchFileWriter::Open(sink_.get(), schema, options)); + ARROW_ASSIGN_OR_RAISE(writer_, NewFileWriter(sink_.get(), schema, options)); return Status::OK(); } @@ -735,10 +792,11 @@ struct FileWriterHelper { return sink_->Tell().Value(&footer_offset_); } - Status ReadBatches(BatchVector* out_batches) { + Status ReadBatches(const IpcReadOptions& options, BatchVector* out_batches) { auto buf_reader = std::make_shared(buffer_); std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchFileReader::Open(buf_reader.get(), footer_offset_, &reader)); + ARROW_ASSIGN_OR_RAISE( + reader, RecordBatchFileReader::Open(buf_reader.get(), footer_offset_, options)); EXPECT_EQ(num_batches_written_, reader->num_record_batches()); for (int i = 0; i < num_batches_written_; ++i) { @@ -752,8 +810,8 @@ struct FileWriterHelper { Status ReadSchema(std::shared_ptr* out) { auto buf_reader = std::make_shared(buffer_); - std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchFileReader::Open(buf_reader.get(), footer_offset_, &reader)); + ARROW_ASSIGN_OR_RAISE(auto reader, + RecordBatchFileReader::Open(buf_reader.get(), footer_offset_)); *out = reader->schema(); return Status::OK(); @@ -767,11 +825,10 @@ struct FileWriterHelper { }; struct StreamWriterHelper { - Status Init(const std::shared_ptr& schema, const IpcOptions& options) { + Status Init(const std::shared_ptr& schema, const IpcWriteOptions& options) { RETURN_NOT_OK(AllocateResizableBuffer(0, &buffer_)); sink_.reset(new io::BufferOutputStream(buffer_)); - ARROW_ASSIGN_OR_RAISE(writer_, - RecordBatchStreamWriter::Open(sink_.get(), schema, options)); + ARROW_ASSIGN_OR_RAISE(writer_, NewStreamWriter(sink_.get(), schema, options)); return Status::OK(); } @@ -785,17 +842,16 @@ struct StreamWriterHelper { return sink_->Close(); } - Status ReadBatches(BatchVector* out_batches) { + Status ReadBatches(const IpcReadOptions& options, BatchVector* out_batches) { auto buf_reader = std::make_shared(buffer_); std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchStreamReader::Open(buf_reader, &reader)); + ARROW_ASSIGN_OR_RAISE(reader, RecordBatchStreamReader::Open(buf_reader, options)) return reader->ReadAll(out_batches); } Status ReadSchema(std::shared_ptr* out) { auto buf_reader = std::make_shared(buffer_); - std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchStreamReader::Open(buf_reader.get(), &reader)); + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchStreamReader::Open(buf_reader.get())); *out = reader->schema(); return Status::OK(); @@ -806,7 +862,7 @@ struct StreamWriterHelper { std::shared_ptr writer_; }; -// Parameterized mixin with tests for RecordBatchStreamWriter / RecordBatchFileWriter +// Parameterized mixin with tests for stream / file writer template class ReaderWriterMixin { @@ -815,7 +871,7 @@ class ReaderWriterMixin { // Check simple RecordBatch roundtripping template - void TestRoundTrip(Param&& param, const IpcOptions& options) { + void TestRoundTrip(Param&& param, const IpcWriteOptions& options) { std::shared_ptr batch1; std::shared_ptr batch2; ASSERT_OK(param(&batch1)); // NOLINT clang-tidy gtest issue @@ -824,7 +880,8 @@ class ReaderWriterMixin { BatchVector in_batches = {batch1, batch2}; BatchVector out_batches; - ASSERT_OK(RoundTripHelper(in_batches, options, &out_batches)); + ASSERT_OK( + RoundTripHelper(in_batches, options, IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), in_batches.size()); // Compare batches @@ -834,7 +891,7 @@ class ReaderWriterMixin { } template - void TestZeroLengthRoundTrip(Param&& param, const IpcOptions& options) { + void TestZeroLengthRoundTrip(Param&& param, const IpcWriteOptions& options) { std::shared_ptr batch1; std::shared_ptr batch2; ASSERT_OK(param(&batch1)); // NOLINT clang-tidy gtest issue @@ -845,7 +902,8 @@ class ReaderWriterMixin { BatchVector in_batches = {batch1, batch2}; BatchVector out_batches; - ASSERT_OK(RoundTripHelper(in_batches, options, &out_batches)); + ASSERT_OK( + RoundTripHelper(in_batches, options, IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), in_batches.size()); // Compare batches @@ -859,7 +917,8 @@ class ReaderWriterMixin { ASSERT_OK(MakeDictionary(&batch)); BatchVector out_batches; - ASSERT_OK(RoundTripHelper({batch}, IpcOptions::Defaults(), &out_batches)); + ASSERT_OK(RoundTripHelper({batch}, IpcWriteOptions::Defaults(), + IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), 1); // TODO(wesm): This was broken in ARROW-3144. I'm not sure how to @@ -870,6 +929,40 @@ class ReaderWriterMixin { // CheckDictionariesDeduplicated(*out_batches[0]); } + void TestReadSubsetOfFields() { + // Part of ARROW-7979 + auto a0 = ArrayFromJSON(utf8(), "[\"a0\", null]"); + auto a1 = ArrayFromJSON(utf8(), "[\"a1\", null]"); + auto a2 = ArrayFromJSON(utf8(), "[\"a2\", null]"); + auto a3 = ArrayFromJSON(utf8(), "[\"a3\", null]"); + + auto my_schema = schema({field("a0", utf8()), field("a1", utf8()), + field("a2", utf8()), field("a3", utf8())}, + key_value_metadata({"key1"}, {"value1"})); + auto batch = RecordBatch::Make(my_schema, a0->length(), {a0, a1, a2, a3}); + + IpcReadOptions options = IpcReadOptions::Defaults(); + + options.included_fields = {1, 3}; + + BatchVector out_batches; + ASSERT_OK( + RoundTripHelper({batch}, IpcWriteOptions::Defaults(), options, &out_batches)); + + auto ex_schema = schema({field("a1", utf8()), field("a3", utf8())}, + key_value_metadata({"key1"}, {"value1"})); + auto ex_batch = RecordBatch::Make(ex_schema, a0->length(), {a1, a3}); + AssertBatchesEqual(*ex_batch, *out_batches[0], /*check_metadata=*/true); + + // Out of bounds cases + options.included_fields = {1, 3, 5}; + ASSERT_RAISES(Invalid, RoundTripHelper({batch}, IpcWriteOptions::Defaults(), options, + &out_batches)); + options.included_fields = {1, 3, -1}; + ASSERT_RAISES(Invalid, RoundTripHelper({batch}, IpcWriteOptions::Defaults(), options, + &out_batches)); + } + void TestWriteDifferentSchema() { // Test writing batches with a different schema than the RecordBatchWriter // was initialized with. @@ -882,7 +975,7 @@ class ReaderWriterMixin { schema = schema->WithMetadata(key_value_metadata({"some_key"}, {"some_value"})); WriterHelper writer_helper; - ASSERT_OK(writer_helper.Init(schema, IpcOptions::Defaults())); + ASSERT_OK(writer_helper.Init(schema, IpcWriteOptions::Defaults())); // Writing a record batch with a different schema ASSERT_RAISES(Invalid, writer_helper.WriteBatch(batch_ints)); // Writing a record batch with the same schema (except metadata) @@ -891,7 +984,7 @@ class ReaderWriterMixin { // The single successful batch can be read again BatchVector out_batches; - ASSERT_OK(writer_helper.ReadBatches(&out_batches)); + ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), 1); CompareBatch(*out_batches[0], *batch_bools, false /* compare_metadata */); // Metadata from the RecordBatchWriter initialization schema was kept @@ -903,11 +996,11 @@ class ReaderWriterMixin { auto schema = arrow::schema({field("a", int32())}); WriterHelper writer_helper; - ASSERT_OK(writer_helper.Init(schema, IpcOptions::Defaults())); + ASSERT_OK(writer_helper.Init(schema, IpcWriteOptions::Defaults())); ASSERT_OK(writer_helper.Finish()); BatchVector out_batches; - ASSERT_OK(writer_helper.ReadBatches(&out_batches)); + ASSERT_OK(writer_helper.ReadBatches(IpcReadOptions::Defaults(), &out_batches)); ASSERT_EQ(out_batches.size(), 0); std::shared_ptr actual_schema; @@ -916,15 +1009,16 @@ class ReaderWriterMixin { } private: - Status RoundTripHelper(const BatchVector& in_batches, const IpcOptions& options, - BatchVector* out_batches) { + Status RoundTripHelper(const BatchVector& in_batches, + const IpcWriteOptions& write_options, + const IpcReadOptions& read_options, BatchVector* out_batches) { WriterHelper writer_helper; - RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema(), options)); + RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema(), write_options)); for (const auto& batch : in_batches) { RETURN_NOT_OK(writer_helper.WriteBatch(batch)); } RETURN_NOT_OK(writer_helper.Finish()); - RETURN_NOT_OK(writer_helper.ReadBatches(out_batches)); + RETURN_NOT_OK(writer_helper.ReadBatches(read_options, out_batches)); for (const auto& batch : *out_batches) { RETURN_NOT_OK(batch->ValidateFull()); } @@ -954,20 +1048,20 @@ class TestStreamFormat : public ReaderWriterMixin, public ::testing::TestWithParam {}; TEST_P(TestFileFormat, RoundTrip) { - TestRoundTrip(*GetParam(), IpcOptions::Defaults()); - TestZeroLengthRoundTrip(*GetParam(), IpcOptions::Defaults()); + TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); - IpcOptions options; + IpcWriteOptions options; options.write_legacy_ipc_format = true; TestRoundTrip(*GetParam(), options); TestZeroLengthRoundTrip(*GetParam(), options); } TEST_P(TestStreamFormat, RoundTrip) { - TestRoundTrip(*GetParam(), IpcOptions::Defaults()); - TestZeroLengthRoundTrip(*GetParam(), IpcOptions::Defaults()); + TestRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); + TestZeroLengthRoundTrip(*GetParam(), IpcWriteOptions::Defaults()); - IpcOptions options; + IpcWriteOptions options; options.write_legacy_ipc_format = true; TestRoundTrip(*GetParam(), options); TestZeroLengthRoundTrip(*GetParam(), options); @@ -1022,6 +1116,10 @@ TEST_F(TestStreamFormat, NoRecordBatches) { TestWriteNoRecordBatches(); } TEST_F(TestFileFormat, NoRecordBatches) { TestWriteNoRecordBatches(); } +TEST_F(TestStreamFormat, ReadFieldSubset) { TestReadSubsetOfFields(); } + +TEST_F(TestFileFormat, ReadFieldSubset) { TestReadSubsetOfFields(); } + TEST(TestRecordBatchStreamReader, EmptyStreamWithDictionaries) { // ARROW-6006 auto f0 = arrow::field("f0", arrow::dictionary(arrow::int8(), arrow::utf8())); @@ -1029,14 +1127,13 @@ TEST(TestRecordBatchStreamReader, EmptyStreamWithDictionaries) { ASSERT_OK_AND_ASSIGN(auto stream, io::BufferOutputStream::Create(0)); - std::shared_ptr writer; - ASSERT_OK(RecordBatchStreamWriter::Open(stream.get(), schema, &writer)); + ASSERT_OK_AND_ASSIGN(auto writer, NewStreamWriter(stream.get(), schema)); ASSERT_OK(writer->Close()); ASSERT_OK_AND_ASSIGN(auto buffer, stream->Finish()); io::BufferReader buffer_reader(buffer); std::shared_ptr reader; - ASSERT_OK(RecordBatchStreamReader::Open(&buffer_reader, &reader)); + ASSERT_OK_AND_ASSIGN(reader, RecordBatchStreamReader::Open(&buffer_reader)); std::shared_ptr batch; ASSERT_OK(reader->ReadNext(&batch)); @@ -1069,7 +1166,7 @@ void SpliceMessages(std::shared_ptr stream, continue; } - IpcOptions options; + IpcWriteOptions options; internal::IpcPayload payload; payload.type = msg->type(); payload.metadata = msg->metadata(); @@ -1088,8 +1185,7 @@ TEST(TestRecordBatchStreamReader, NotEnoughDictionaries) { ASSERT_OK(MakeDictionaryFlat(&batch)); ASSERT_OK_AND_ASSIGN(auto out, io::BufferOutputStream::Create(0)); - std::shared_ptr writer; - ASSERT_OK(RecordBatchStreamWriter::Open(out.get(), batch->schema(), &writer)); + ASSERT_OK_AND_ASSIGN(auto writer, NewStreamWriter(out.get(), batch->schema())); ASSERT_OK(writer->WriteRecordBatch(*batch)); ASSERT_OK(writer->Close()); @@ -1099,8 +1195,7 @@ TEST(TestRecordBatchStreamReader, NotEnoughDictionaries) { auto AssertFailsWith = [](std::shared_ptr stream, const std::string& ex_error) { io::BufferReader reader(stream); - std::shared_ptr ipc_reader; - ASSERT_OK(RecordBatchStreamReader::Open(&reader, &ipc_reader)); + ASSERT_OK_AND_ASSIGN(auto ipc_reader, RecordBatchStreamReader::Open(&reader)); std::shared_ptr batch; Status s = ipc_reader->ReadNext(&batch); ASSERT_TRUE(s.IsInvalid()); @@ -1471,13 +1566,11 @@ TEST(TestRecordBatchStreamReader, MalformedInput) { auto empty = std::make_shared(empty_str); auto garbage = std::make_shared(garbage_str); - std::shared_ptr batch_reader; - io::BufferReader empty_reader(empty); - ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&empty_reader, &batch_reader)); + ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&empty_reader)); io::BufferReader garbage_reader(garbage); - ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader, &batch_reader)); + ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader)); } // ---------------------------------------------------------------------- @@ -1511,10 +1604,10 @@ TEST(TestDictionaryMemo, ReusedDictionaries) { ASSERT_TRUE(memo.HasDictionary(*field2)); int64_t returned_id = -1; - ASSERT_OK(memo.GetId(*field1, &returned_id)); + ASSERT_OK(memo.GetId(field1.get(), &returned_id)); ASSERT_EQ(0, returned_id); returned_id = -1; - ASSERT_OK(memo.GetId(*field2, &returned_id)); + ASSERT_OK(memo.GetId(field2.get(), &returned_id)); ASSERT_EQ(0, returned_id); } diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 293c8f123f1..1675bac989b 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -17,9 +17,9 @@ #include "arrow/ipc/reader.h" +#include #include #include -#include #include #include #include @@ -29,23 +29,30 @@ #include "arrow/array.h" #include "arrow/buffer.h" +#include "arrow/extension_type.h" #include "arrow/io/interfaces.h" #include "arrow/io/memory.h" -#include "arrow/ipc/dictionary.h" #include "arrow/ipc/message.h" #include "arrow/ipc/metadata_internal.h" +#include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/sparse_tensor.h" #include "arrow/status.h" -#include "arrow/tensor.h" #include "arrow/type.h" #include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/compression.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" +#include "arrow/util/optional.h" +#include "arrow/util/ubsan.h" #include "arrow/visitor_inline.h" #include "generated/File_generated.h" // IWYU pragma: export #include "generated/Message_generated.h" #include "generated/Schema_generated.h" +#include "generated/SparseTensor_generated.h" namespace arrow { @@ -94,11 +101,98 @@ Status InvalidMessageType(Message::Type expected, Message::Type actual) { // ---------------------------------------------------------------------- // Record batch read path -/// Accessor class for flatbuffers metadata -class IpcComponentSource { +/// The field_index and buffer_index are incremented based on how much of the +/// batch is "consumed" (through nested data reconstruction, for example) +class ArrayLoader { public: - IpcComponentSource(const flatbuf::RecordBatch* metadata, io::RandomAccessFile* file) - : metadata_(metadata), file_(file) {} + explicit ArrayLoader(const flatbuf::RecordBatch* metadata, + const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, Compression::type compression, + io::RandomAccessFile* file) + : metadata_(metadata), + file_(file), + dictionary_memo_(dictionary_memo), + options_(options), + compression_(compression), + max_recursion_depth_(options.max_recursion_depth) {} + + Status ReadBuffer(int64_t offset, int64_t length, std::shared_ptr* out) { + if (skip_io_) { + return Status::OK(); + } + // This construct permits overriding GetBuffer at compile time + if (!BitUtil::IsMultipleOf8(offset)) { + return Status::Invalid("Buffer ", buffer_index_, + " did not start on 8-byte aligned offset: ", offset); + } + return file_->ReadAt(offset, length).Value(out); + } + + Status LoadType(const DataType& type) { return VisitTypeInline(type, this); } + + Status DecompressBuffers() { + if (skip_io_) { + return Status::OK(); + } + std::unique_ptr codec; + ARROW_ASSIGN_OR_RAISE(codec, util::Codec::Create(compression_)); + + // TODO: Consider strategies to enable columns to decompress in parallel + for (size_t i = 0; i < out_->buffers.size(); ++i) { + if (out_->buffers[i] == nullptr) { + continue; + } + if (out_->buffers[i]->size() == 0) { + continue; + } + const uint8_t* data = out_->buffers[i]->data(); + int64_t compressed_size = out_->buffers[i]->size() - sizeof(int64_t); + int64_t uncompressed_size = util::SafeLoadAs(data); + + std::shared_ptr uncompressed; + RETURN_NOT_OK( + AllocateBuffer(options_.memory_pool, uncompressed_size, &uncompressed)); + + int64_t actual_decompressed; + ARROW_ASSIGN_OR_RAISE( + actual_decompressed, + codec->Decompress(compressed_size, data + sizeof(int64_t), uncompressed_size, + uncompressed->mutable_data())); + if (actual_decompressed != uncompressed_size) { + return Status::Invalid("Failed to fully decompress buffer, expected ", + uncompressed_size, " bytes but decompressed ", + actual_decompressed); + } + out_->buffers[i] = uncompressed; + } + return Status::OK(); + } + + Status Load(const Field* field, ArrayData* out) { + if (max_recursion_depth_ <= 0) { + return Status::Invalid("Max recursion depth reached"); + } + + field_ = field; + out_ = out; + out_->type = field_->type(); + RETURN_NOT_OK(LoadType(*field_->type())); + + // If the buffers are indicated to be compressed, instantiate the codec and + // decompress them + if (compression_ != Compression::UNCOMPRESSED) { + RETURN_NOT_OK(DecompressBuffers()); + } + return Status::OK(); + } + + Status SkipField(const Field* field) { + ArrayData dummy; + skip_io_ = true; + Status status = Load(field, &dummy); + skip_io_ = false; + return status; + } Status GetBuffer(int buffer_index, std::shared_ptr* out) { auto buffers = metadata_->buffers(); @@ -107,18 +201,12 @@ class IpcComponentSource { return Status::IOError("buffer_index out of range."); } const flatbuf::Buffer* buffer = buffers->Get(buffer_index); - if (buffer->length() == 0) { // Should never return a null buffer here. // (zero-sized buffer allocations are cheap) return AllocateBuffer(0, out); } else { - if (!BitUtil::IsMultipleOf8(buffer->offset())) { - return Status::Invalid( - "Buffer ", buffer_index, - " did not start on 8-byte aligned offset: ", buffer->offset()); - } - return file_->ReadAt(buffer->offset(), buffer->length()).Value(out); + return ReadBuffer(buffer->offset(), buffer->length(), out); } } @@ -137,58 +225,19 @@ class IpcComponentSource { return Status::OK(); } - private: - const flatbuf::RecordBatch* metadata_; - io::RandomAccessFile* file_; -}; - -/// Bookkeeping struct for loading array objects from their constituent pieces of raw data -/// -/// The field_index and buffer_index are incremented in the ArrayLoader -/// based on how much of the batch is "consumed" (through nested data -/// reconstruction, for example) -struct ArrayLoaderContext { - IpcComponentSource* source; - const DictionaryMemo* dictionary_memo; - int buffer_index; - int field_index; - int max_recursion_depth; -}; - -static Status LoadArray(const Field& field, ArrayLoaderContext* context, ArrayData* out); - -class ArrayLoader { - public: - ArrayLoader(const Field& field, ArrayData* out, ArrayLoaderContext* context) - : field_(field), context_(context), out_(out) {} - - Status Load() { - if (context_->max_recursion_depth <= 0) { - return Status::Invalid("Max recursion depth reached"); - } - - RETURN_NOT_OK(VisitTypeInline(*field_.type(), this)); - out_->type = field_.type(); - return Status::OK(); - } - - Status GetBuffer(int buffer_index, std::shared_ptr* out) { - return context_->source->GetBuffer(buffer_index, out); - } - Status LoadCommon() { // This only contains the length and null count, which we need to figure // out what to do with the buffers. For example, if null_count == 0, then // we can skip that buffer without reading from shared memory - RETURN_NOT_OK(context_->source->GetFieldMetadata(context_->field_index++, out_)); + RETURN_NOT_OK(GetFieldMetadata(field_index_++, out_)); // extract null_bitmap which is common to all arrays if (out_->null_count == 0) { out_->buffers[0] = nullptr; } else { - RETURN_NOT_OK(GetBuffer(context_->buffer_index, &out_->buffers[0])); + RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[0])); } - context_->buffer_index++; + buffer_index_++; return Status::OK(); } @@ -198,9 +247,9 @@ class ArrayLoader { RETURN_NOT_OK(LoadCommon()); if (out_->length > 0) { - RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1])); + RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1])); } else { - context_->buffer_index++; + buffer_index_++; out_->buffers[1].reset(new Buffer(nullptr, 0)); } return Status::OK(); @@ -211,8 +260,8 @@ class ArrayLoader { out_->buffers.resize(3); RETURN_NOT_OK(LoadCommon()); - RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1])); - return GetBuffer(context_->buffer_index++, &out_->buffers[2]); + RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1])); + return GetBuffer(buffer_index_++, &out_->buffers[2]); } template @@ -220,7 +269,7 @@ class ArrayLoader { out_->buffers.resize(2); RETURN_NOT_OK(LoadCommon()); - RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &out_->buffers[1])); + RETURN_NOT_OK(GetBuffer(buffer_index_++, &out_->buffers[1])); const int num_children = type.num_children(); if (num_children != 1) { @@ -230,22 +279,17 @@ class ArrayLoader { return LoadChildren(type.children()); } - Status LoadChild(const Field& field, ArrayData* out) { - ArrayLoader loader(field, out, context_); - --context_->max_recursion_depth; - RETURN_NOT_OK(loader.Load()); - ++context_->max_recursion_depth; - return Status::OK(); - } - Status LoadChildren(std::vector> child_fields) { - out_->child_data.reserve(static_cast(child_fields.size())); - + ArrayData* parent = out_; + parent->child_data.reserve(static_cast(child_fields.size())); for (const auto& child_field : child_fields) { auto field_array = std::make_shared(); - RETURN_NOT_OK(LoadChild(*child_field, field_array.get())); - out_->child_data.emplace_back(field_array); + --max_recursion_depth_; + RETURN_NOT_OK(Load(child_field.get(), field_array.get())); + ++max_recursion_depth_; + parent->child_data.emplace_back(field_array); } + out_ = parent; return Status::OK(); } @@ -253,8 +297,7 @@ class ArrayLoader { out_->buffers.resize(1); // ARROW-6379: NullType has no buffers in the IPC payload - RETURN_NOT_OK(context_->source->GetFieldMetadata(context_->field_index++, out_)); - return Status::OK(); + return GetFieldMetadata(field_index_++, out_); } template @@ -274,7 +317,7 @@ class ArrayLoader { Status Visit(const FixedSizeBinaryType& type) { out_->buffers.resize(2); RETURN_NOT_OK(LoadCommon()); - return GetBuffer(context_->buffer_index++, &out_->buffers[1]); + return GetBuffer(buffer_index_++, &out_->buffers[1]); } template @@ -311,102 +354,149 @@ class ArrayLoader { RETURN_NOT_OK(LoadCommon()); if (out_->length > 0) { - RETURN_NOT_OK(GetBuffer(context_->buffer_index, &out_->buffers[1])); + RETURN_NOT_OK(GetBuffer(buffer_index_, &out_->buffers[1])); if (type.mode() == UnionMode::DENSE) { - RETURN_NOT_OK(GetBuffer(context_->buffer_index + 1, &out_->buffers[2])); + RETURN_NOT_OK(GetBuffer(buffer_index_ + 1, &out_->buffers[2])); } } - context_->buffer_index += type.mode() == UnionMode::DENSE ? 2 : 1; + buffer_index_ += type.mode() == UnionMode::DENSE ? 2 : 1; return LoadChildren(type.children()); } Status Visit(const DictionaryType& type) { - RETURN_NOT_OK( - LoadArray(*::arrow::field("indices", type.index_type()), context_, out_)); + RETURN_NOT_OK(LoadType(*type.index_type())); // Look up dictionary int64_t id = -1; - RETURN_NOT_OK(context_->dictionary_memo->GetId(field_, &id)); - RETURN_NOT_OK(context_->dictionary_memo->GetDictionary(id, &out_->dictionary)); + RETURN_NOT_OK(dictionary_memo_->GetId(field_, &id)); + RETURN_NOT_OK(dictionary_memo_->GetDictionary(id, &out_->dictionary)); return Status::OK(); } - Status Visit(const ExtensionType& type) { - return LoadArray(*::arrow::field("storage", type.storage_type()), context_, out_); - } + Status Visit(const ExtensionType& type) { return LoadType(*type.storage_type()); } private: - const Field& field_; - ArrayLoaderContext* context_; - - // Used in visitor pattern + const flatbuf::RecordBatch* metadata_; + io::RandomAccessFile* file_; + const DictionaryMemo* dictionary_memo_; + const IpcReadOptions& options_; + Compression::type compression_; + int max_recursion_depth_; + int buffer_index_ = 0; + int field_index_ = 0; + bool skip_io_ = false; + + const Field* field_; ArrayData* out_; }; -static Status LoadArray(const Field& field, ArrayLoaderContext* context, ArrayData* out) { - ArrayLoader loader(field, out, context); - return loader.Load(); -} +Result> LoadRecordBatchSubset( + const flatbuf::RecordBatch* metadata, const std::shared_ptr& schema, + const std::vector& inclusion_mask, const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, Compression::type compression, + io::RandomAccessFile* file) { + ArrayLoader loader(metadata, dictionary_memo, options, compression, file); -Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file, - std::shared_ptr* out) { - auto options = IpcOptions::Defaults(); - return ReadRecordBatch(metadata, schema, dictionary_memo, options, file, out); -} + std::vector> field_data; + std::vector> schema_fields; -Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, - std::shared_ptr* out) { - CHECK_MESSAGE_TYPE(Message::RECORD_BATCH, message.type()); - CHECK_HAS_BODY(message); - auto options = IpcOptions::Defaults(); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); - return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options, - reader.get(), out); -} + for (int i = 0; i < schema->num_fields(); ++i) { + if (inclusion_mask[i]) { + // Read field + auto arr = std::make_shared(); + RETURN_NOT_OK(loader.Load(schema->field(i).get(), arr.get())); + if (metadata->length() != arr->length) { + return Status::IOError("Array length did not match record batch length"); + } + field_data.emplace_back(std::move(arr)); + schema_fields.emplace_back(schema->field(i)); + } else { + // Skip field. This logic must be executed to advance the state of the + // loader to the next field + RETURN_NOT_OK(loader.SkipField(schema->field(i).get())); + } + } -// ---------------------------------------------------------------------- -// Array loading + return RecordBatch::Make(::arrow::schema(std::move(schema_fields), schema->metadata()), + metadata->length(), std::move(field_data)); +} -static Status LoadRecordBatchFromSource(const std::shared_ptr& schema, - int64_t num_rows, int max_recursion_depth, - IpcComponentSource* source, - const DictionaryMemo* dictionary_memo, - std::shared_ptr* out) { - ArrayLoaderContext context{source, dictionary_memo, /*field_index=*/0, - /*buffer_index=*/0, max_recursion_depth}; +Result> LoadRecordBatch( + const flatbuf::RecordBatch* metadata, const std::shared_ptr& schema, + const std::vector& inclusion_mask, const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, Compression::type compression, + io::RandomAccessFile* file) { + if (inclusion_mask.size() > 0) { + return LoadRecordBatchSubset(metadata, schema, inclusion_mask, dictionary_memo, + options, compression, file); + } + ArrayLoader loader(metadata, dictionary_memo, options, compression, file); std::vector> arrays(schema->num_fields()); for (int i = 0; i < schema->num_fields(); ++i) { auto arr = std::make_shared(); - RETURN_NOT_OK(LoadArray(*schema->field(i), &context, arr.get())); - if (num_rows != arr->length) { + RETURN_NOT_OK(loader.Load(schema->field(i).get(), arr.get())); + if (metadata->length() != arr->length) { return Status::IOError("Array length did not match record batch length"); } arrays[i] = std::move(arr); } + return RecordBatch::Make(schema, metadata->length(), std::move(arrays)); +} + +// ---------------------------------------------------------------------- +// Array loading + +Status GetCompression(const flatbuf::Message* message, Compression::type* out) { + *out = Compression::UNCOMPRESSED; + if (message->custom_metadata() != nullptr) { + // TODO: Ensure this deserialization only ever happens once + std::shared_ptr metadata; + RETURN_NOT_OK(internal::GetKeyValueMetadata(message->custom_metadata(), &metadata)); + int index = metadata->FindKey("ARROW:body_compression"); + if (index != -1) { + ARROW_ASSIGN_OR_RAISE(*out, + util::Codec::GetCompressionType(metadata->value(index))); + } + } + return Status::OK(); +} - *out = RecordBatch::Make(schema, num_rows, std::move(arrays)); +static Status ReadContiguousPayload(io::InputStream* file, + std::unique_ptr* message) { + RETURN_NOT_OK(ReadMessage(file, message)); + if (*message == nullptr) { + return Status::Invalid("Unable to read metadata at offset"); + } return Status::OK(); } -static inline Status ReadRecordBatch(const flatbuf::RecordBatch* metadata, - const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, - const IpcOptions& options, - io::RandomAccessFile* file, - std::shared_ptr* out) { - IpcComponentSource source(metadata, file); - return LoadRecordBatchFromSource(schema, metadata->length(), - options.max_recursion_depth, &source, dictionary_memo, - out); +Result> ReadRecordBatch( + const std::shared_ptr& schema, const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, io::InputStream* file) { + std::unique_ptr message; + RETURN_NOT_OK(ReadContiguousPayload(file, &message)); + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options, + reader.get()); } -Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, const IpcOptions& options, - io::RandomAccessFile* file, std::shared_ptr* out) { +Result> ReadRecordBatch( + const Message& message, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, const IpcReadOptions& options) { + CHECK_MESSAGE_TYPE(Message::RECORD_BATCH, message.type()); + CHECK_HAS_BODY(message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); + return ReadRecordBatch(*message.metadata(), schema, dictionary_memo, options, + reader.get()); +} + +Result> ReadRecordBatchInternal( + const Buffer& metadata, const std::shared_ptr& schema, + const std::vector& inclusion_mask, const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, io::RandomAccessFile* file) { const flatbuf::Message* message; RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message)); auto batch = message->header_as_RecordBatch(); @@ -414,13 +504,41 @@ Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& sc return Status::IOError( "Header-type of flatbuffer-encoded Message is not RecordBatch."); } - return ReadRecordBatch(batch, schema, dictionary_memo, options, file, out); + Compression::type compression; + RETURN_NOT_OK(GetCompression(message, &compression)); + return LoadRecordBatch(batch, schema, inclusion_mask, dictionary_memo, options, + compression, file); } -Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, - io::RandomAccessFile* file) { - auto options = IpcOptions::Defaults(); +Status PopulateInclusionMask(const std::vector& included_indices, + int schema_num_fields, std::vector* mask) { + mask->resize(schema_num_fields, false); + for (int i : included_indices) { + // Ignore out of bounds indices + if (i < 0 || i >= schema_num_fields) { + return Status::Invalid("Out of bounds field index: ", i); + } + (*mask)[i] = true; + } + return Status::OK(); +} + +Result> ReadRecordBatch( + const Buffer& metadata, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, const IpcReadOptions& options, + io::RandomAccessFile* file) { + // Empty means do not use + std::vector inclusion_mask; + if (options.included_fields) { + RETURN_NOT_OK(PopulateInclusionMask(*options.included_fields, schema->num_fields(), + &inclusion_mask)); + } + return ReadRecordBatchInternal(metadata, schema, inclusion_mask, dictionary_memo, + options, file); +} +Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, io::RandomAccessFile* file) { const flatbuf::Message* message; RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message)); auto dictionary_batch = message->header_as_DictionaryBatch(); @@ -429,6 +547,9 @@ Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, "Header-type of flatbuffer-encoded Message is not DictionaryBatch."); } + Compression::type compression; + RETURN_NOT_OK(GetCompression(message, &compression)); + int64_t id = dictionary_batch->id(); // Look up the field, which must have been added to the @@ -439,11 +560,14 @@ Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, auto value_field = ::arrow::field("dummy", value_type); // The dictionary is embedded in a record batch with a single column - std::shared_ptr batch; auto batch_meta = dictionary_batch->data(); CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data"); - RETURN_NOT_OK(ReadRecordBatch(batch_meta, ::arrow::schema({value_field}), - dictionary_memo, options, file, &batch)); + + std::shared_ptr batch; + ARROW_ASSIGN_OR_RAISE( + batch, LoadRecordBatch(batch_meta, ::arrow::schema({value_field}), + /*field_inclusion_mask=*/{}, dictionary_memo, options, + compression, file)); if (batch->num_columns() != 1) { return Status::Invalid("Dictionary record batch must only contain one field"); } @@ -454,21 +578,14 @@ Status ReadDictionary(const Buffer& metadata, DictionaryMemo* dictionary_memo, // ---------------------------------------------------------------------- // RecordBatchStreamReader implementation -static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) { - return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()}; -} - -class RecordBatchStreamReader::RecordBatchStreamReaderImpl { +class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { public: - RecordBatchStreamReaderImpl() {} - ~RecordBatchStreamReaderImpl() {} - - Status Open(std::unique_ptr message_reader) { + Status Open(std::unique_ptr message_reader, + const IpcReadOptions& options) { message_reader_ = std::move(message_reader); - return ReadSchema(); - } + options_ = options; - Status ReadSchema() { + // Read schema std::unique_ptr message; RETURN_NOT_OK(message_reader_->ReadNextMessage(&message)); if (!message) { @@ -476,15 +593,59 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { } CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type()); CHECK_HAS_NO_BODY(*message); - return internal::GetSchema(message->header(), &dictionary_memo_, &schema_); + + RETURN_NOT_OK(internal::GetSchema(message->header(), &dictionary_memo_, &schema_)); + + // If we are selecting only certain fields, populate the inclusion mask now + // for fast lookups + if (options.included_fields) { + RETURN_NOT_OK(PopulateInclusionMask(*options.included_fields, schema_->num_fields(), + &field_inclusion_mask_)); + } + return Status::OK(); } + Status ReadNext(std::shared_ptr* batch) override { + if (!have_read_initial_dictionaries_) { + RETURN_NOT_OK(ReadInitialDictionaries()); + } + + if (empty_stream_) { + // ARROW-6006: Degenerate case where stream contains no data, we do not + // bother trying to read a RecordBatch message from the stream + *batch = nullptr; + return Status::OK(); + } + + std::unique_ptr message; + RETURN_NOT_OK(message_reader_->ReadNextMessage(&message)); + if (message == nullptr) { + // End of stream + *batch = nullptr; + return Status::OK(); + } + + if (message->type() == Message::DICTIONARY_BATCH) { + // TODO(wesm): implement delta dictionaries + return Status::NotImplemented("Delta dictionaries not yet implemented"); + } else { + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + return ReadRecordBatchInternal(*message->metadata(), schema_, field_inclusion_mask_, + &dictionary_memo_, options_, reader.get()) + .Value(batch); + } + } + + std::shared_ptr schema() const override { return schema_; } + + private: Status ParseDictionary(const Message& message) { // Only invoke this method if we already know we have a dictionary message DCHECK_EQ(message.type(), Message::DICTIONARY_BATCH); CHECK_HAS_BODY(message); ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message.body())); - return ReadDictionary(*message.metadata(), &dictionary_memo_, reader.get()); + return ReadDictionary(*message.metadata(), &dictionary_memo_, options_, reader.get()); } Status ReadInitialDictionaries() { @@ -521,47 +682,15 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { RETURN_NOT_OK(ParseDictionary(*message)); } - read_initial_dictionaries_ = true; + have_read_initial_dictionaries_ = true; return Status::OK(); } - Status ReadNext(std::shared_ptr* batch) { - if (!read_initial_dictionaries_) { - RETURN_NOT_OK(ReadInitialDictionaries()); - } - - if (empty_stream_) { - // ARROW-6006: Degenerate case where stream contains no data, we do not - // bother trying to read a RecordBatch message from the stream - *batch = nullptr; - return Status::OK(); - } - - std::unique_ptr message; - RETURN_NOT_OK(message_reader_->ReadNextMessage(&message)); - if (message == nullptr) { - // End of stream - *batch = nullptr; - return Status::OK(); - } - - if (message->type() == Message::DICTIONARY_BATCH) { - // TODO(wesm): implement delta dictionaries - return Status::NotImplemented("Delta dictionaries not yet implemented"); - } else { - CHECK_HAS_BODY(*message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - return ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_, - reader.get(), batch); - } - } - - std::shared_ptr schema() const { return schema_; } - - private: std::unique_ptr message_reader_; + IpcReadOptions options_; + std::vector field_inclusion_mask_; - bool read_initial_dictionaries_ = false; + bool have_read_initial_dictionaries_ = false; // Flag to set in case where we fail to observe all dictionaries in a stream, // and so the reader should not attempt to parse any messages @@ -571,108 +700,93 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { std::shared_ptr schema_; }; -RecordBatchStreamReader::RecordBatchStreamReader() { - impl_.reset(new RecordBatchStreamReaderImpl()); -} - -RecordBatchStreamReader::~RecordBatchStreamReader() {} - -Status RecordBatchStreamReader::Open(std::unique_ptr message_reader, - std::shared_ptr* reader) { - // Private ctor - auto result = std::shared_ptr(new RecordBatchStreamReader()); - RETURN_NOT_OK(result->impl_->Open(std::move(message_reader))); - *reader = result; - return Status::OK(); -} +// ---------------------------------------------------------------------- +// Stream reader constructors -Status RecordBatchStreamReader::Open(std::unique_ptr message_reader, - std::unique_ptr* reader) { +Result> RecordBatchStreamReader::Open( + std::unique_ptr message_reader, const IpcReadOptions& options) { // Private ctor - auto result = std::unique_ptr(new RecordBatchStreamReader()); - RETURN_NOT_OK(result->impl_->Open(std::move(message_reader))); - *reader = std::move(result); - return Status::OK(); -} - -Status RecordBatchStreamReader::Open(io::InputStream* stream, - std::shared_ptr* out) { - return Open(MessageReader::Open(stream), out); -} - -Status RecordBatchStreamReader::Open(const std::shared_ptr& stream, - std::shared_ptr* out) { - return Open(MessageReader::Open(stream), out); + auto result = std::make_shared(); + RETURN_NOT_OK(result->Open(std::move(message_reader), options)); + return result; } -std::shared_ptr RecordBatchStreamReader::schema() const { - return impl_->schema(); +Result> RecordBatchStreamReader::Open( + io::InputStream* stream, const IpcReadOptions& options) { + return Open(MessageReader::Open(stream), options); } -Status RecordBatchStreamReader::ReadNext(std::shared_ptr* batch) { - return impl_->ReadNext(batch); +Result> RecordBatchStreamReader::Open( + const std::shared_ptr& stream, const IpcReadOptions& options) { + return Open(MessageReader::Open(stream), options); } // ---------------------------------------------------------------------- // Reader implementation -class RecordBatchFileReader::RecordBatchFileReaderImpl { +static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) { + return FileBlock{block->offset(), block->metaDataLength(), block->bodyLength()}; +} + +class RecordBatchFileReaderImpl : public RecordBatchFileReader { public: RecordBatchFileReaderImpl() : file_(NULLPTR), footer_offset_(0), footer_(NULLPTR) {} - Status ReadFooter() { - const int32_t magic_size = static_cast(strlen(kArrowMagicBytes)); + int num_record_batches() const override { + return static_cast(internal::FlatBuffersVectorSize(footer_->recordBatches())); + } - if (footer_offset_ <= magic_size * 2 + 4) { - return Status::Invalid("File is too small: ", footer_offset_); - } + MetadataVersion version() const override { + return internal::GetMetadataVersion(footer_->version()); + } - int file_end_size = static_cast(magic_size + sizeof(int32_t)); - ARROW_ASSIGN_OR_RAISE(auto buffer, - file_->ReadAt(footer_offset_ - file_end_size, file_end_size)); + Status ReadRecordBatch(int i, std::shared_ptr* batch) override { + DCHECK_GE(i, 0); + DCHECK_LT(i, num_record_batches()); - const int64_t expected_footer_size = magic_size + sizeof(int32_t); - if (buffer->size() < expected_footer_size) { - return Status::Invalid("Unable to read ", expected_footer_size, "from end of file"); + if (!read_dictionaries_) { + RETURN_NOT_OK(ReadDictionaries()); + read_dictionaries_ = true; } - if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) { - return Status::Invalid("Not an Arrow file"); - } + std::unique_ptr message; + RETURN_NOT_OK(ReadMessageFromBlock(GetRecordBatchBlock(i), &message)); - int32_t footer_length = *reinterpret_cast(buffer->data()); + CHECK_HAS_BODY(*message); + ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); + return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_, + options_, reader.get()) + .Value(batch); + } - if (footer_length <= 0 || footer_length > footer_offset_ - magic_size * 2 - 4) { - return Status::Invalid("File is smaller than indicated metadata size"); - } + Status Open(const std::shared_ptr& file, int64_t footer_offset, + const IpcReadOptions& options) { + owned_file_ = file; + return Open(file.get(), footer_offset, options); + } - // Now read the footer - ARROW_ASSIGN_OR_RAISE( - footer_buffer_, - file_->ReadAt(footer_offset_ - footer_length - file_end_size, footer_length)); + Status Open(io::RandomAccessFile* file, int64_t footer_offset, + const IpcReadOptions& options) { + file_ = file; + options_ = options; + footer_offset_ = footer_offset; + RETURN_NOT_OK(ReadFooter()); - auto data = footer_buffer_->data(); - flatbuffers::Verifier verifier(data, footer_buffer_->size(), 128); - if (!flatbuf::VerifyFooterBuffer(verifier)) { - return Status::IOError("Verification of flatbuffer-encoded Footer failed."); - } - footer_ = flatbuf::GetFooter(data); + // Get the schema and record any observed dictionaries + RETURN_NOT_OK(internal::GetSchema(footer_->schema(), &dictionary_memo_, &schema_)); + // If we are selecting only certain fields, populate the inclusion mask now + // for fast lookups + if (options.included_fields) { + RETURN_NOT_OK(PopulateInclusionMask(*options.included_fields, schema_->num_fields(), + &field_inclusion_mask_)); + } return Status::OK(); } - int num_dictionaries() const { - return static_cast(internal::FlatBuffersVectorSize(footer_->dictionaries())); - } - - int num_record_batches() const { - return static_cast(internal::FlatBuffersVectorSize(footer_->recordBatches())); - } - - MetadataVersion version() const { - return internal::GetMetadataVersion(footer_->version()); - } + std::shared_ptr schema() const override { return schema_; } + private: FileBlock GetRecordBatchBlock(int i) const { return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i)); } @@ -703,51 +817,60 @@ class RecordBatchFileReader::RecordBatchFileReaderImpl { CHECK_HAS_BODY(*message); ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - RETURN_NOT_OK( - ReadDictionary(*message->metadata(), &dictionary_memo_, reader.get())); + RETURN_NOT_OK(ReadDictionary(*message->metadata(), &dictionary_memo_, options_, + reader.get())); } return Status::OK(); } - Status ReadRecordBatch(int i, std::shared_ptr* batch) { - DCHECK_GE(i, 0); - DCHECK_LT(i, num_record_batches()); + Status ReadFooter() { + const int32_t magic_size = static_cast(strlen(kArrowMagicBytes)); - if (!read_dictionaries_) { - RETURN_NOT_OK(ReadDictionaries()); - read_dictionaries_ = true; + if (footer_offset_ <= magic_size * 2 + 4) { + return Status::Invalid("File is too small: ", footer_offset_); } - std::unique_ptr message; - RETURN_NOT_OK(ReadMessageFromBlock(GetRecordBatchBlock(i), &message)); + int file_end_size = static_cast(magic_size + sizeof(int32_t)); + ARROW_ASSIGN_OR_RAISE(auto buffer, + file_->ReadAt(footer_offset_ - file_end_size, file_end_size)); - CHECK_HAS_BODY(*message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - return ::arrow::ipc::ReadRecordBatch(*message->metadata(), schema_, &dictionary_memo_, - reader.get(), batch); - } + const int64_t expected_footer_size = magic_size + sizeof(int32_t); + if (buffer->size() < expected_footer_size) { + return Status::Invalid("Unable to read ", expected_footer_size, "from end of file"); + } - Status ReadSchema() { - // Get the schema and record any observed dictionaries - return internal::GetSchema(footer_->schema(), &dictionary_memo_, &schema_); - } + if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) { + return Status::Invalid("Not an Arrow file"); + } - Status Open(const std::shared_ptr& file, int64_t footer_offset) { - owned_file_ = file; - return Open(file.get(), footer_offset); - } + int32_t footer_length = *reinterpret_cast(buffer->data()); - Status Open(io::RandomAccessFile* file, int64_t footer_offset) { - file_ = file; - footer_offset_ = footer_offset; - RETURN_NOT_OK(ReadFooter()); - return ReadSchema(); + if (footer_length <= 0 || footer_length > footer_offset_ - magic_size * 2 - 4) { + return Status::Invalid("File is smaller than indicated metadata size"); + } + + // Now read the footer + ARROW_ASSIGN_OR_RAISE( + footer_buffer_, + file_->ReadAt(footer_offset_ - footer_length - file_end_size, footer_length)); + + auto data = footer_buffer_->data(); + flatbuffers::Verifier verifier(data, footer_buffer_->size(), 128); + if (!flatbuf::VerifyFooterBuffer(verifier)) { + return Status::IOError("Verification of flatbuffer-encoded Footer failed."); + } + footer_ = flatbuf::GetFooter(data); + + return Status::OK(); } - std::shared_ptr schema() const { return schema_; } + int num_dictionaries() const { + return static_cast(internal::FlatBuffersVectorSize(footer_->dictionaries())); + } - private: io::RandomAccessFile* file_; + IpcReadOptions options_; + std::vector field_inclusion_mask_; std::shared_ptr owned_file_; @@ -766,61 +889,35 @@ class RecordBatchFileReader::RecordBatchFileReaderImpl { std::shared_ptr schema_; }; -RecordBatchFileReader::RecordBatchFileReader() { - impl_.reset(new RecordBatchFileReaderImpl()); -} - -RecordBatchFileReader::~RecordBatchFileReader() {} - -Status RecordBatchFileReader::Open(io::RandomAccessFile* file, - std::shared_ptr* reader) { +Result> RecordBatchFileReader::Open( + io::RandomAccessFile* file, const IpcReadOptions& options) { ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize()); - return Open(file, footer_offset, reader); + return Open(file, footer_offset, options); } -Status RecordBatchFileReader::Open(io::RandomAccessFile* file, int64_t footer_offset, - std::shared_ptr* reader) { - *reader = std::shared_ptr(new RecordBatchFileReader()); - return (*reader)->impl_->Open(file, footer_offset); +Result> RecordBatchFileReader::Open( + io::RandomAccessFile* file, int64_t footer_offset, const IpcReadOptions& options) { + auto result = std::make_shared(); + RETURN_NOT_OK(result->Open(file, footer_offset, options)); + return result; } -Status RecordBatchFileReader::Open(const std::shared_ptr& file, - std::shared_ptr* reader) { +Result> RecordBatchFileReader::Open( + const std::shared_ptr& file, const IpcReadOptions& options) { ARROW_ASSIGN_OR_RAISE(int64_t footer_offset, file->GetSize()); - return Open(file, footer_offset, reader); -} - -Status RecordBatchFileReader::Open(const std::shared_ptr& file, - int64_t footer_offset, - std::shared_ptr* reader) { - *reader = std::shared_ptr(new RecordBatchFileReader()); - return (*reader)->impl_->Open(file, footer_offset); -} - -std::shared_ptr RecordBatchFileReader::schema() const { return impl_->schema(); } - -int RecordBatchFileReader::num_record_batches() const { - return impl_->num_record_batches(); + return Open(file, footer_offset, options); } -MetadataVersion RecordBatchFileReader::version() const { return impl_->version(); } - -Status RecordBatchFileReader::ReadRecordBatch(int i, - std::shared_ptr* batch) { - return impl_->ReadRecordBatch(i, batch); +Result> RecordBatchFileReader::Open( + const std::shared_ptr& file, int64_t footer_offset, + const IpcReadOptions& options) { + auto result = std::make_shared(); + RETURN_NOT_OK(result->Open(file, footer_offset, options)); + return result; } -static Status ReadContiguousPayload(io::InputStream* file, - std::unique_ptr* message) { - RETURN_NOT_OK(ReadMessage(file, message)); - if (*message == nullptr) { - return Status::Invalid("Unable to read metadata at offset"); - } - return Status::OK(); -} - -Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, - std::shared_ptr* out) { +Result> ReadSchema(io::InputStream* stream, + DictionaryMemo* dictionary_memo) { std::unique_ptr reader = MessageReader::Open(stream); std::unique_ptr message; RETURN_NOT_OK(reader->ReadNextMessage(&message)); @@ -828,25 +925,14 @@ Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, return Status::Invalid("Tried reading schema message, was null or length 0"); } CHECK_MESSAGE_TYPE(Message::SCHEMA, message->type()); - return ReadSchema(*message, dictionary_memo, out); + return ReadSchema(*message, dictionary_memo); } -Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, - std::shared_ptr* out) { - std::shared_ptr reader; - return internal::GetSchema(message.header(), dictionary_memo, &*out); -} - -Status ReadRecordBatch(const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, io::InputStream* file, - std::shared_ptr* out) { - auto options = IpcOptions::Defaults(); - std::unique_ptr message; - RETURN_NOT_OK(ReadContiguousPayload(file, &message)); - CHECK_HAS_BODY(*message); - ARROW_ASSIGN_OR_RAISE(auto reader, Buffer::GetReader(message->body())); - return ReadRecordBatch(*message->metadata(), schema, dictionary_memo, options, - reader.get(), out); +Result> ReadSchema(const Message& message, + DictionaryMemo* dictionary_memo) { + std::shared_ptr result; + RETURN_NOT_OK(internal::GetSchema(message.header(), dictionary_memo, &result)); + return result; } Result> ReadTensor(io::InputStream* file) { @@ -1192,7 +1278,7 @@ Status FuzzIpcStream(const uint8_t* data, int64_t size) { io::BufferReader buffer_reader(buffer); std::shared_ptr batch_reader; - RETURN_NOT_OK(RecordBatchStreamReader::Open(&buffer_reader, &batch_reader)); + ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchStreamReader::Open(&buffer_reader)); while (true) { std::shared_ptr batch; @@ -1211,7 +1297,7 @@ Status FuzzIpcFile(const uint8_t* data, int64_t size) { io::BufferReader buffer_reader(buffer); std::shared_ptr batch_reader; - RETURN_NOT_OK(RecordBatchFileReader::Open(&buffer_reader, &batch_reader)); + ARROW_ASSIGN_OR_RAISE(batch_reader, RecordBatchFileReader::Open(&buffer_reader)); const int n_batches = batch_reader->num_record_batches(); for (int i = 0; i < n_batches; ++i) { @@ -1224,5 +1310,86 @@ Status FuzzIpcFile(const uint8_t* data, int64_t size) { } } // namespace internal + +// ---------------------------------------------------------------------- +// Deprecated functions + +Status RecordBatchStreamReader::Open(std::unique_ptr message_reader, + std::shared_ptr* out) { + return Open(std::move(message_reader), IpcReadOptions::Defaults()).Value(out); +} + +Status RecordBatchStreamReader::Open(std::unique_ptr message_reader, + std::unique_ptr* out) { + auto result = + std::unique_ptr(new RecordBatchStreamReaderImpl()); + RETURN_NOT_OK(result->Open(std::move(message_reader), IpcReadOptions::Defaults())); + *out = std::move(result); + return Status::OK(); +} + +Status RecordBatchStreamReader::Open(io::InputStream* stream, + std::shared_ptr* out) { + return Open(MessageReader::Open(stream)).Value(out); +} + +Status RecordBatchStreamReader::Open(const std::shared_ptr& stream, + std::shared_ptr* out) { + return Open(MessageReader::Open(stream)).Value(out); +} + +Status RecordBatchFileReader::Open(io::RandomAccessFile* file, + std::shared_ptr* out) { + return Open(file).Value(out); +} + +Status RecordBatchFileReader::Open(io::RandomAccessFile* file, int64_t footer_offset, + std::shared_ptr* out) { + return Open(file, footer_offset).Value(out); +} + +Status RecordBatchFileReader::Open(const std::shared_ptr& file, + std::shared_ptr* out) { + return Open(file).Value(out); +} + +Status RecordBatchFileReader::Open(const std::shared_ptr& file, + int64_t footer_offset, + std::shared_ptr* out) { + return Open(file, footer_offset).Value(out); +} + +Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { + return ReadSchema(stream, dictionary_memo).Value(out); +} + +Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { + return ReadSchema(message, dictionary_memo).Value(out); +} + +Status ReadRecordBatch(const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, io::InputStream* stream, + std::shared_ptr* out) { + return ReadRecordBatch(schema, dictionary_memo, IpcReadOptions::Defaults(), stream) + .Value(out); +} + +Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, + std::shared_ptr* out) { + return ReadRecordBatch(message, schema, dictionary_memo, IpcReadOptions::Defaults()) + .Value(out); +} + +Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file, + std::shared_ptr* out) { + return ReadRecordBatch(metadata, schema, dictionary_memo, IpcReadOptions::Defaults(), + file) + .Value(out); +} + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index bc24c22a13d..642c3ad02e0 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -19,15 +19,15 @@ #pragma once +#include #include #include -#include "arrow/ipc/dictionary.h" #include "arrow/ipc/message.h" #include "arrow/ipc/options.h" -#include "arrow/ipc/writer.h" #include "arrow/record_batch.h" -#include "arrow/sparse_tensor.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" #include "arrow/util/visibility.h" namespace arrow { @@ -47,6 +47,14 @@ class RandomAccessFile; namespace ipc { +class DictionaryMemo; + +namespace internal { + +struct IpcPayload; + +} // namespace internal + using RecordBatchReader = ::arrow::RecordBatchReader; /// \class RecordBatchStreamReader @@ -57,50 +65,51 @@ using RecordBatchReader = ::arrow::RecordBatchReader; /// reads see the ReadRecordBatch functions class ARROW_EXPORT RecordBatchStreamReader : public RecordBatchReader { public: - ~RecordBatchStreamReader() override; - /// Create batch reader from generic MessageReader. /// This will take ownership of the given MessageReader. /// /// \param[in] message_reader a MessageReader implementation - /// \param[out] out the created RecordBatchReader object - /// \return Status - static Status Open(std::unique_ptr message_reader, - std::shared_ptr* out); - static Status Open(std::unique_ptr message_reader, - std::unique_ptr* out); + /// \param[in] options any IPC reading options (optional) + /// \return the created batch reader + static Result> Open( + std::unique_ptr message_reader, + const IpcReadOptions& options = IpcReadOptions::Defaults()); /// \brief Record batch stream reader from InputStream /// /// \param[in] stream an input stream instance. Must stay alive throughout /// lifetime of stream reader - /// \param[out] out the created RecordBatchStreamReader object - /// \return Status - static Status Open(io::InputStream* stream, std::shared_ptr* out); + /// \param[in] options any IPC reading options (optional) + /// \return the created batch reader + static Result> Open( + io::InputStream* stream, + const IpcReadOptions& options = IpcReadOptions::Defaults()); /// \brief Open stream and retain ownership of stream object /// \param[in] stream the input stream - /// \param[out] out the batch reader - /// \return Status + /// \param[in] options any IPC reading options (optional) + /// \return the created batch reader + static Result> Open( + const std::shared_ptr& stream, + const IpcReadOptions& options = IpcReadOptions::Defaults()); + + ARROW_DEPRECATED("Use Result-returning version") + static Status Open(std::unique_ptr message_reader, + std::shared_ptr* out); + ARROW_DEPRECATED("Use Result-returning version") + static Status Open(std::unique_ptr message_reader, + std::unique_ptr* out); + ARROW_DEPRECATED("Use Result-returning version") + static Status Open(io::InputStream* stream, std::shared_ptr* out); + ARROW_DEPRECATED("Use Result-returning version") static Status Open(const std::shared_ptr& stream, std::shared_ptr* out); - - /// \brief Returns the schema read from the stream - std::shared_ptr schema() const override; - - Status ReadNext(std::shared_ptr* batch) override; - - private: - RecordBatchStreamReader(); - - class ARROW_NO_EXPORT RecordBatchStreamReaderImpl; - std::unique_ptr impl_; }; /// \brief Reads the record batch file format class ARROW_EXPORT RecordBatchFileReader { public: - ~RecordBatchFileReader(); + virtual ~RecordBatchFileReader() = default; /// \brief Open a RecordBatchFileReader /// @@ -109,8 +118,9 @@ class ARROW_EXPORT RecordBatchFileReader { /// can be any amount of data preceding the Arrow-formatted data, because we /// need only locate the end of the Arrow file stream to discover the metadata /// and then proceed to read the data into memory. - static Status Open(io::RandomAccessFile* file, - std::shared_ptr* reader); + static Result> Open( + io::RandomAccessFile* file, + const IpcReadOptions& options = IpcReadOptions::Defaults()); /// \brief Open a RecordBatchFileReader /// If the file is embedded within some larger file or memory region, you can @@ -120,37 +130,52 @@ class ARROW_EXPORT RecordBatchFileReader { /// /// \param[in] file the data source /// \param[in] footer_offset the position of the end of the Arrow file - /// \param[out] reader the returned reader - /// \return Status - static Status Open(io::RandomAccessFile* file, int64_t footer_offset, - std::shared_ptr* reader); + /// \param[in] options options for IPC reading + /// \return the returned reader + static Result> Open( + io::RandomAccessFile* file, int64_t footer_offset, + const IpcReadOptions& options = IpcReadOptions::Defaults()); /// \brief Version of Open that retains ownership of file /// /// \param[in] file the data source - /// \param[out] reader the returned reader - /// \return Status - static Status Open(const std::shared_ptr& file, - std::shared_ptr* reader); + /// \param[in] options options for IPC reading + /// \return the returned reader + static Result> Open( + const std::shared_ptr& file, + const IpcReadOptions& options = IpcReadOptions::Defaults()); /// \brief Version of Open that retains ownership of file /// /// \param[in] file the data source /// \param[in] footer_offset the position of the end of the Arrow file - /// \param[out] reader the returned reader - /// \return Status + /// \param[in] options options for IPC reading + /// \return the returned reader + static Result> Open( + const std::shared_ptr& file, int64_t footer_offset, + const IpcReadOptions& options = IpcReadOptions::Defaults()); + + ARROW_DEPRECATED("Use Result-returning version") static Status Open(const std::shared_ptr& file, - int64_t footer_offset, - std::shared_ptr* reader); + int64_t footer_offset, std::shared_ptr* out); + ARROW_DEPRECATED("Use Result-returning version") + static Status Open(const std::shared_ptr& file, + std::shared_ptr* out); + ARROW_DEPRECATED("Use Result-returning version") + static Status Open(io::RandomAccessFile* file, int64_t footer_offset, + std::shared_ptr* out); + ARROW_DEPRECATED("Use Result-returning version") + static Status Open(io::RandomAccessFile* file, + std::shared_ptr* out); /// \brief The schema read from the file - std::shared_ptr schema() const; + virtual std::shared_ptr schema() const = 0; /// \brief Returns the number of record batches in the file - int num_record_batches() const; + virtual int num_record_batches() const = 0; /// \brief Return the metadata version from the file metadata - MetadataVersion version() const; + virtual MetadataVersion version() const = 0; /// \brief Read a particular record batch from the file. Does not copy memory /// if the input source supports zero-copy. @@ -158,13 +183,7 @@ class ARROW_EXPORT RecordBatchFileReader { /// \param[in] i the index of the record batch to return /// \param[out] batch the read batch /// \return Status - Status ReadRecordBatch(int i, std::shared_ptr* batch); - - private: - RecordBatchFileReader(); - - class ARROW_NO_EXPORT RecordBatchFileReaderImpl; - std::unique_ptr impl_; + virtual Status ReadRecordBatch(int i, std::shared_ptr* batch) = 0; }; // Generic read functions; does not copy data if the input supports zero copy reads @@ -174,26 +193,24 @@ class ARROW_EXPORT RecordBatchFileReader { /// /// \param[in] stream an InputStream /// \param[in] dictionary_memo for recording dictionary-encoded fields -/// \param[out] out the output Schema -/// \return Status +/// \return the output Schema /// /// If record batches follow the schema, it is better to use /// RecordBatchStreamReader ARROW_EXPORT -Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, - std::shared_ptr* out); +Result> ReadSchema(io::InputStream* stream, + DictionaryMemo* dictionary_memo); /// \brief Read Schema from encapsulated Message /// -/// \param[in] message a message instance containing metadata +/// \param[in] message the message containing the Schema IPC metadata /// \param[in] dictionary_memo DictionaryMemo for recording dictionary-encoded /// fields. Can be nullptr if you are sure there are no /// dictionary-encoded fields -/// \param[out] out the resulting Schema -/// \return Status +/// \return the resulting Schema ARROW_EXPORT -Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, - std::shared_ptr* out); +Result> ReadSchema(const Message& message, + DictionaryMemo* dictionary_memo); /// Read record batch as encapsulated IPC message with metadata size prefix and /// header @@ -202,42 +219,27 @@ Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, /// \param[in] dictionary_memo DictionaryMemo which has any /// dictionaries. Can be nullptr if you are sure there are no /// dictionary-encoded fields +/// \param[in] options IPC options for reading /// \param[in] stream the file where the batch is located -/// \param[out] out the read record batch -/// \return Status +/// \return the read record batch ARROW_EXPORT -Status ReadRecordBatch(const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, io::InputStream* stream, - std::shared_ptr* out); +Result> ReadRecordBatch( + const std::shared_ptr& schema, const DictionaryMemo* dictionary_memo, + const IpcReadOptions& options, io::InputStream* stream); -/// \brief Read record batch from file given metadata and schema +/// \brief Read record batch from message /// -/// \param[in] metadata a Message containing the record batch metadata +/// \param[in] message a Message containing the record batch metadata /// \param[in] schema the record batch schema /// \param[in] dictionary_memo DictionaryMemo which has any /// dictionaries. Can be nullptr if you are sure there are no /// dictionary-encoded fields -/// \param[in] file a random access file -/// \param[out] out the read record batch -/// \return Status +/// \param[in] options IPC options for reading +/// \return the read record batch ARROW_EXPORT -Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file, - std::shared_ptr* out); - -/// \brief Read record batch from encapsulated Message -/// -/// \param[in] message a message instance containing metadata and body -/// \param[in] schema the record batch schema -/// \param[in] dictionary_memo DictionaryMemo which has any -/// dictionaries. Can be nullptr if you are sure there are no -/// dictionary-encoded fields -/// \param[out] out the resulting RecordBatch -/// \return Status -ARROW_EXPORT -Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, - std::shared_ptr* out); +Result> ReadRecordBatch( + const Message& message, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, const IpcReadOptions& options); /// Read record batch from file given metadata and schema /// @@ -248,12 +250,12 @@ Status ReadRecordBatch(const Message& message, const std::shared_ptr& sc /// dictionary-encoded fields /// \param[in] file a random access file /// \param[in] options options for deserialization -/// \param[out] out the read record batch -/// \return Status +/// \return the read record batch ARROW_EXPORT -Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, - const DictionaryMemo* dictionary_memo, const IpcOptions& options, - io::RandomAccessFile* file, std::shared_ptr* out); +Result> ReadRecordBatch( + const Buffer& metadata, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, const IpcReadOptions& options, + io::RandomAccessFile* file); /// \brief Read arrow::Tensor as encapsulated IPC message in file /// @@ -307,5 +309,33 @@ Status FuzzIpcFile(const uint8_t* data, int64_t size); } // namespace internal +ARROW_DEPRECATED("Use version with Result return value") +ARROW_EXPORT +Status ReadSchema(io::InputStream* stream, DictionaryMemo* dictionary_memo, + std::shared_ptr* out); + +ARROW_DEPRECATED("Use version with Result return value") +ARROW_EXPORT +Status ReadSchema(const Message& message, DictionaryMemo* dictionary_memo, + std::shared_ptr* out); + +ARROW_DEPRECATED("Use version with Result return value") +ARROW_EXPORT +Status ReadRecordBatch(const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, io::InputStream* stream, + std::shared_ptr* out); + +ARROW_DEPRECATED("Use version with Result return value") +ARROW_EXPORT +Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, io::RandomAccessFile* file, + std::shared_ptr* out); + +ARROW_DEPRECATED("Use version with Result return value") +ARROW_EXPORT +Status ReadRecordBatch(const Message& message, const std::shared_ptr& schema, + const DictionaryMemo* dictionary_memo, + std::shared_ptr* out); + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/stream_to_file.cc b/cpp/src/arrow/ipc/stream_to_file.cc index d0c7444db75..126195aae9b 100644 --- a/cpp/src/arrow/ipc/stream_to_file.cc +++ b/cpp/src/arrow/ipc/stream_to_file.cc @@ -34,13 +34,10 @@ namespace ipc { // $ | stream-to-file > file.arrow Status ConvertToFile() { io::StdinStream input; - std::shared_ptr reader; - RETURN_NOT_OK(RecordBatchStreamReader::Open(&input, &reader)); - io::StdoutStream sink; - std::shared_ptr writer; - RETURN_NOT_OK(RecordBatchFileWriter::Open(&sink, reader->schema(), &writer)); + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchStreamReader::Open(&input)); + ARROW_ASSIGN_OR_RAISE(auto writer, NewFileWriter(&sink, reader->schema())); std::shared_ptr batch; while (true) { RETURN_NOT_OK(reader->ReadNext(&batch)); diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 7589c83148c..18f075ea014 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -22,11 +22,14 @@ #include #include #include +#include +#include #include #include #include "arrow/array.h" #include "arrow/buffer.h" +#include "arrow/device.h" #include "arrow/extension_type.h" #include "arrow/io/interfaces.h" #include "arrow/io/memory.h" @@ -34,39 +37,37 @@ #include "arrow/ipc/message.h" #include "arrow/ipc/metadata_internal.h" #include "arrow/ipc/util.h" -#include "arrow/memory_pool.h" #include "arrow/record_batch.h" #include "arrow/result_internal.h" #include "arrow/sparse_tensor.h" #include "arrow/status.h" #include "arrow/table.h" -#include "arrow/tensor.h" #include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/bit_util.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/compression.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/visitor.h" +#include "arrow/visitor_inline.h" namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; using internal::CopyBitmap; -using internal::make_unique; namespace ipc { using internal::FileBlock; using internal::kArrowMagicBytes; -// ---------------------------------------------------------------------- -// Record batch write path +namespace internal { -static inline Status GetTruncatedBitmap(int64_t offset, int64_t length, - const std::shared_ptr input, - MemoryPool* pool, - std::shared_ptr* buffer) { +Status GetTruncatedBitmap(int64_t offset, int64_t length, + const std::shared_ptr input, MemoryPool* pool, + std::shared_ptr* buffer) { if (!input) { *buffer = input; return Status::OK(); @@ -81,15 +82,13 @@ static inline Status GetTruncatedBitmap(int64_t offset, int64_t length, return Status::OK(); } -template -inline Status GetTruncatedBuffer(int64_t offset, int64_t length, - const std::shared_ptr input, MemoryPool* pool, - std::shared_ptr* buffer) { +Status GetTruncatedBuffer(int64_t offset, int64_t length, int32_t byte_width, + const std::shared_ptr input, MemoryPool* pool, + std::shared_ptr* buffer) { if (!input) { *buffer = input; return Status::OK(); } - int32_t byte_width = static_cast(sizeof(T)); int64_t padded_length = PaddedLength(length * byte_width); if (offset != 0 || padded_length < input->size()) { *buffer = @@ -109,21 +108,18 @@ static inline bool NeedTruncate(int64_t offset, const Buffer* buffer, return offset != 0 || min_length < buffer->size(); } -namespace internal { - -class RecordBatchSerializer : public ArrayVisitor { +class RecordBatchSerializer { public: - RecordBatchSerializer(MemoryPool* pool, int64_t buffer_start_offset, - const IpcOptions& options, IpcPayload* out) + RecordBatchSerializer(int64_t buffer_start_offset, const IpcWriteOptions& options, + IpcPayload* out) : out_(out), - pool_(pool), options_(options), max_recursion_depth_(options.max_recursion_depth), buffer_start_offset_(buffer_start_offset) { DCHECK_GT(max_recursion_depth_, 0); } - ~RecordBatchSerializer() override = default; + virtual ~RecordBatchSerializer() = default; Status VisitArray(const Array& arr) { static std::shared_ptr kNullBuffer = std::make_shared(nullptr, 0); @@ -144,20 +140,61 @@ class RecordBatchSerializer : public ArrayVisitor { if (arr.null_count() > 0) { std::shared_ptr bitmap; RETURN_NOT_OK(GetTruncatedBitmap(arr.offset(), arr.length(), arr.null_bitmap(), - pool_, &bitmap)); + options_.memory_pool, &bitmap)); out_->body_buffers.emplace_back(bitmap); } else { // Push a dummy zero-length buffer, not to be copied out_->body_buffers.emplace_back(kNullBuffer); } } - return arr.Accept(this); + return VisitType(arr); } // Override this for writing dictionary metadata virtual Status SerializeMetadata(int64_t num_rows) { - return WriteRecordBatchMessage(num_rows, out_->body_length, field_nodes_, - buffer_meta_, &out_->metadata); + return WriteRecordBatchMessage(num_rows, out_->body_length, custom_metadata_, + field_nodes_, buffer_meta_, &out_->metadata); + } + + void AppendCustomMetadata(const std::string& key, const std::string& value) { + if (!custom_metadata_) { + custom_metadata_ = std::make_shared(); + } + custom_metadata_->Append(key, value); + } + + Status CompressBuffer(const Buffer& buffer, util::Codec* codec, + std::shared_ptr* out) { + // Convert buffer to uncompressed-length-prefixed compressed buffer + int64_t maximum_length = codec->MaxCompressedLen(buffer.size(), buffer.data()); + std::shared_ptr result; + RETURN_NOT_OK(AllocateBuffer(maximum_length + sizeof(int64_t), &result)); + + int64_t actual_length; + ARROW_ASSIGN_OR_RAISE(actual_length, + codec->Compress(buffer.size(), buffer.data(), maximum_length, + result->mutable_data() + sizeof(int64_t))); + *reinterpret_cast(result->mutable_data()) = buffer.size(); + *out = SliceBuffer(result, /*offset=*/0, actual_length + sizeof(int64_t)); + return Status::OK(); + } + + Status CompressBodyBuffers() { + std::unique_ptr codec; + + AppendCustomMetadata("ARROW:body_compression", + util::Codec::GetCodecAsString(options_.compression)); + + ARROW_ASSIGN_OR_RAISE( + codec, util::Codec::Create(options_.compression, options_.compression_level)); + // TODO: Parallelize buffer compression + for (size_t i = 0; i < out_->body_buffers.size(); ++i) { + if (out_->body_buffers[i]->size() > 0) { + RETURN_NOT_OK( + CompressBuffer(*out_->body_buffers[i], codec.get(), &out_->body_buffers[i])); + } + } + return Status::OK(); } Status Assemble(const RecordBatch& batch) { @@ -172,6 +209,10 @@ class RecordBatchSerializer : public ArrayVisitor { RETURN_NOT_OK(VisitArray(*batch.column(i))); } + if (options_.compression != Compression::UNCOMPRESSED) { + RETURN_NOT_OK(CompressBodyBuffers()); + } + // The position for the start of a buffer relative to the passed frame of // reference. May be 0 or some other position in an address space int64_t offset = buffer_start_offset_; @@ -205,29 +246,6 @@ class RecordBatchSerializer : public ArrayVisitor { return SerializeMetadata(batch.num_rows()); } - protected: - template - Status VisitFixedWidth(const ArrayType& array) { - std::shared_ptr data = array.values(); - - const auto& fw_type = checked_cast(*array.type()); - const int64_t type_width = fw_type.bit_width() / 8; - int64_t min_length = PaddedLength(array.length() * type_width); - - if (NeedTruncate(array.offset(), data.get(), min_length)) { - // Non-zero offset, slice the buffer - const int64_t byte_offset = array.offset() * type_width; - - // Send padding if it's available - const int64_t buffer_length = - std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width), - data->size() - byte_offset); - data = SliceBuffer(data, byte_offset, buffer_length); - } - out_->body_buffers.emplace_back(data); - return Status::OK(); - } - template Status GetZeroBasedValueOffsets(const ArrayType& array, std::shared_ptr* value_offsets) { @@ -243,7 +261,8 @@ class RecordBatchSerializer : public ArrayVisitor { // b) slice the values array accordingly std::shared_ptr shifted_offsets; - RETURN_NOT_OK(AllocateBuffer(pool_, required_bytes, &shifted_offsets)); + RETURN_NOT_OK( + AllocateBuffer(options_.memory_pool, required_bytes, &shifted_offsets)); offset_type* dest_offsets = reinterpret_cast(shifted_offsets->mutable_data()); @@ -266,10 +285,46 @@ class RecordBatchSerializer : public ArrayVisitor { return Status::OK(); } - template - Status VisitBinary(const ArrayType& array) { + Status Visit(const BooleanArray& array) { + std::shared_ptr data; + RETURN_NOT_OK(GetTruncatedBitmap(array.offset(), array.length(), array.values(), + options_.memory_pool, &data)); + out_->body_buffers.emplace_back(data); + return Status::OK(); + } + + Status Visit(const NullArray& array) { return Status::OK(); } + + template + typename std::enable_if::value || + is_temporal_type::value || + is_fixed_size_binary_type::value, + Status>::type + Visit(const T& array) { + std::shared_ptr data = array.values(); + + const auto& fw_type = checked_cast(*array.type()); + const int64_t type_width = fw_type.bit_width() / 8; + int64_t min_length = PaddedLength(array.length() * type_width); + + if (NeedTruncate(array.offset(), data.get(), min_length)) { + // Non-zero offset, slice the buffer + const int64_t byte_offset = array.offset() * type_width; + + // Send padding if it's available + const int64_t buffer_length = + std::min(BitUtil::RoundUpToMultipleOf8(array.length() * type_width), + data->size() - byte_offset); + data = SliceBuffer(data, byte_offset, buffer_length); + } + out_->body_buffers.emplace_back(data); + return Status::OK(); + } + + template + enable_if_base_binary Visit(const T& array) { std::shared_ptr value_offsets; - RETURN_NOT_OK(GetZeroBasedValueOffsets(array, &value_offsets)); + RETURN_NOT_OK(GetZeroBasedValueOffsets(array, &value_offsets)); auto data = array.value_data(); int64_t total_data_bytes = 0; @@ -289,12 +344,12 @@ class RecordBatchSerializer : public ArrayVisitor { return Status::OK(); } - template - Status VisitList(const ArrayType& array) { - using offset_type = typename ArrayType::offset_type; + template + enable_if_base_list Visit(const T& array) { + using offset_type = typename T::offset_type; std::shared_ptr value_offsets; - RETURN_NOT_OK(GetZeroBasedValueOffsets(array, &value_offsets)); + RETURN_NOT_OK(GetZeroBasedValueOffsets(array, &value_offsets)); out_->body_buffers.emplace_back(value_offsets); --max_recursion_depth_; @@ -316,58 +371,7 @@ class RecordBatchSerializer : public ArrayVisitor { return Status::OK(); } - Status Visit(const BooleanArray& array) override { - std::shared_ptr data; - RETURN_NOT_OK( - GetTruncatedBitmap(array.offset(), array.length(), array.values(), pool_, &data)); - out_->body_buffers.emplace_back(data); - return Status::OK(); - } - - Status Visit(const NullArray& array) override { return Status::OK(); } - -#define VISIT_FIXED_WIDTH(TYPE) \ - Status Visit(const TYPE& array) override { return VisitFixedWidth(array); } - - VISIT_FIXED_WIDTH(Int8Array) - VISIT_FIXED_WIDTH(Int16Array) - VISIT_FIXED_WIDTH(Int32Array) - VISIT_FIXED_WIDTH(Int64Array) - VISIT_FIXED_WIDTH(UInt8Array) - VISIT_FIXED_WIDTH(UInt16Array) - VISIT_FIXED_WIDTH(UInt32Array) - VISIT_FIXED_WIDTH(UInt64Array) - VISIT_FIXED_WIDTH(HalfFloatArray) - VISIT_FIXED_WIDTH(FloatArray) - VISIT_FIXED_WIDTH(DoubleArray) - VISIT_FIXED_WIDTH(Date32Array) - VISIT_FIXED_WIDTH(Date64Array) - VISIT_FIXED_WIDTH(TimestampArray) - VISIT_FIXED_WIDTH(DurationArray) - VISIT_FIXED_WIDTH(MonthIntervalArray) - VISIT_FIXED_WIDTH(DayTimeIntervalArray) - VISIT_FIXED_WIDTH(Time32Array) - VISIT_FIXED_WIDTH(Time64Array) - VISIT_FIXED_WIDTH(FixedSizeBinaryArray) - VISIT_FIXED_WIDTH(Decimal128Array) - -#undef VISIT_FIXED_WIDTH - - Status Visit(const StringArray& array) override { return VisitBinary(array); } - - Status Visit(const BinaryArray& array) override { return VisitBinary(array); } - - Status Visit(const LargeStringArray& array) override { return VisitBinary(array); } - - Status Visit(const LargeBinaryArray& array) override { return VisitBinary(array); } - - Status Visit(const ListArray& array) override { return VisitList(array); } - - Status Visit(const LargeListArray& array) override { return VisitList(array); } - - Status Visit(const MapArray& array) override { return VisitList(array); } - - Status Visit(const FixedSizeListArray& array) override { + Status Visit(const FixedSizeListArray& array) { --max_recursion_depth_; auto size = array.list_type()->list_size(); auto values = array.values()->Slice(array.offset() * size, array.length() * size); @@ -377,7 +381,7 @@ class RecordBatchSerializer : public ArrayVisitor { return Status::OK(); } - Status Visit(const StructArray& array) override { + Status Visit(const StructArray& array) { --max_recursion_depth_; for (int i = 0; i < array.num_fields(); ++i) { std::shared_ptr field = array.field(i); @@ -387,13 +391,14 @@ class RecordBatchSerializer : public ArrayVisitor { return Status::OK(); } - Status Visit(const UnionArray& array) override { + Status Visit(const UnionArray& array) { const int64_t offset = array.offset(); const int64_t length = array.length(); std::shared_ptr type_codes; - RETURN_NOT_OK(GetTruncatedBuffer( - offset, length, array.type_codes(), pool_, &type_codes)); + RETURN_NOT_OK(GetTruncatedBuffer( + offset, length, static_cast(sizeof(UnionArray::type_code_t)), + array.type_codes(), options_.memory_pool, &type_codes)); out_->body_buffers.emplace_back(type_codes); --max_recursion_depth_; @@ -401,8 +406,9 @@ class RecordBatchSerializer : public ArrayVisitor { const auto& type = checked_cast(*array.type()); std::shared_ptr value_offsets; - RETURN_NOT_OK(GetTruncatedBuffer(offset, length, array.value_offsets(), - pool_, &value_offsets)); + RETURN_NOT_OK(GetTruncatedBuffer( + offset, length, static_cast(sizeof(int32_t)), array.value_offsets(), + options_.memory_pool, &value_offsets)); // The Union type codes are not necessary 0-indexed int8_t max_code = 0; @@ -427,8 +433,8 @@ class RecordBatchSerializer : public ArrayVisitor { // Allocate the shifted offsets std::shared_ptr shifted_offsets_buffer; - RETURN_NOT_OK( - AllocateBuffer(pool_, length * sizeof(int32_t), &shifted_offsets_buffer)); + RETURN_NOT_OK(AllocateBuffer(options_.memory_pool, length * sizeof(int32_t), + &shifted_offsets_buffer)); int32_t* shifted_offsets = reinterpret_cast(shifted_offsets_buffer->mutable_data()); @@ -486,39 +492,40 @@ class RecordBatchSerializer : public ArrayVisitor { return Status::OK(); } - Status Visit(const DictionaryArray& array) override { + Status Visit(const DictionaryArray& array) { // Dictionary written out separately. Slice offset contained in the indices - return array.indices()->Accept(this); + return VisitType(*array.indices()); } - Status Visit(const ExtensionArray& array) override { - return array.storage()->Accept(this); - } + Status Visit(const ExtensionArray& array) { return VisitType(*array.storage()); } + Status VisitType(const Array& values) { return VisitArrayInline(values, this); } + + protected: // Destination for output buffers IpcPayload* out_; - // In some cases, intermediate buffers may need to be allocated (with sliced arrays) - MemoryPool* pool_; + std::shared_ptr custom_metadata_; - std::vector field_nodes_; - std::vector buffer_meta_; + std::vector field_nodes_; + std::vector buffer_meta_; - const IpcOptions& options_; + const IpcWriteOptions& options_; int64_t max_recursion_depth_; int64_t buffer_start_offset_; }; -class DictionaryWriter : public RecordBatchSerializer { +class DictionarySerializer : public RecordBatchSerializer { public: - DictionaryWriter(int64_t dictionary_id, MemoryPool* pool, int64_t buffer_start_offset, - const IpcOptions& options, IpcPayload* out) - : RecordBatchSerializer(pool, buffer_start_offset, options, out), + DictionarySerializer(int64_t dictionary_id, int64_t buffer_start_offset, + const IpcWriteOptions& options, IpcPayload* out) + : RecordBatchSerializer(buffer_start_offset, options, out), dictionary_id_(dictionary_id) {} Status SerializeMetadata(int64_t num_rows) override { return WriteDictionaryMessage(dictionary_id_, num_rows, out_->body_length, - field_nodes_, buffer_meta_, &out_->metadata); + custom_metadata_, field_nodes_, buffer_meta_, + &out_->metadata); } Status Assemble(const std::shared_ptr& dictionary) { @@ -532,7 +539,7 @@ class DictionaryWriter : public RecordBatchSerializer { int64_t dictionary_id_; }; -Status WriteIpcPayload(const IpcPayload& payload, const IpcOptions& options, +Status WriteIpcPayload(const IpcPayload& payload, const IpcWriteOptions& options, io::OutputStream* dst, int32_t* metadata_length) { RETURN_NOT_OK(WriteMessage(*payload.metadata, options, dst, metadata_length)); @@ -568,39 +575,37 @@ Status WriteIpcPayload(const IpcPayload& payload, const IpcOptions& options, return Status::OK(); } -Status GetSchemaPayload(const Schema& schema, const IpcOptions& options, +Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options, DictionaryMemo* dictionary_memo, IpcPayload* out) { out->type = Message::SCHEMA; return WriteSchemaMessage(schema, dictionary_memo, &out->metadata); } Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, - const IpcOptions& options, MemoryPool* pool, - IpcPayload* out) { + const IpcWriteOptions& options, IpcPayload* out) { out->type = Message::DICTIONARY_BATCH; // Frame of reference is 0, see ARROW-384 - DictionaryWriter writer(id, pool, /*buffer_start_offset=*/0, options, out); - return writer.Assemble(dictionary); + DictionarySerializer assembler(id, /*buffer_start_offset=*/0, options, out); + return assembler.Assemble(dictionary); } -Status GetRecordBatchPayload(const RecordBatch& batch, const IpcOptions& options, - MemoryPool* pool, IpcPayload* out) { +Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options, + IpcPayload* out) { out->type = Message::RECORD_BATCH; - RecordBatchSerializer writer(pool, /*buffer_start_offset=*/0, options, out); - return writer.Assemble(batch); + RecordBatchSerializer assembler(/*buffer_start_offset=*/0, options, out); + return assembler.Assemble(batch); } } // namespace internal Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, - int64_t* body_length, const IpcOptions& options, - MemoryPool* pool) { + int64_t* body_length, const IpcWriteOptions& options) { internal::IpcPayload payload; - internal::RecordBatchSerializer writer(pool, buffer_start_offset, options, &payload); - RETURN_NOT_OK(writer.Assemble(batch)); + internal::RecordBatchSerializer assembler(buffer_start_offset, options, &payload); + RETURN_NOT_OK(assembler.Assemble(batch)); - // TODO(wesm): it's a rough edge that the metadata and body length here are + // TODO: it's a rough edge that the metadata and body length here are // computed separately // The body size is computed in the payload @@ -610,9 +615,9 @@ Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, } Status WriteRecordBatchStream(const std::vector>& batches, - const IpcOptions& options, io::OutputStream* dst) { + const IpcWriteOptions& options, io::OutputStream* dst) { ASSIGN_OR_RAISE(std::shared_ptr writer, - RecordBatchStreamWriter::Open(dst, batches[0]->schema(), options)); + NewStreamWriter(dst, batches[0]->schema(), options)); for (const auto& batch : batches) { DCHECK(batch->schema()->Equals(*batches[0]->schema())) << "Schemas unequal"; RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); @@ -627,7 +632,7 @@ Status WriteTensorHeader(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length) { std::shared_ptr metadata; ARROW_ASSIGN_OR_RAISE(metadata, internal::WriteTensorMessage(tensor, 0)); - IpcOptions options; + IpcWriteOptions options; options.alignment = kTensorAlignment; return WriteMessage(*metadata, options, dst, metadata_length); } @@ -697,8 +702,8 @@ Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadat Tensor dummy(tensor.type(), nullptr, tensor.shape()); RETURN_NOT_OK(WriteTensorHeader(dummy, dst, metadata_length)); - // TODO(wesm): Do we care enough about this temporary allocation to pass in - // a MemoryPool to this function? + // TODO: Do we care enough about this temporary allocation to pass in a + // MemoryPool to this function? std::shared_ptr scratch_space; RETURN_NOT_OK( AllocateBuffer(tensor.shape()[tensor.ndim() - 1] * elem_size, &scratch_space)); @@ -832,7 +837,8 @@ Status WriteSparseTensor(const SparseTensor& sparse_tensor, io::OutputStream* ds RETURN_NOT_OK(writer.Assemble(sparse_tensor)); *body_length = payload.body_length; - return internal::WriteIpcPayload(payload, IpcOptions::Defaults(), dst, metadata_length); + return internal::WriteIpcPayload(payload, IpcWriteOptions::Defaults(), dst, + metadata_length); } Status GetSparseTensorMessage(const SparseTensor& sparse_tensor, MemoryPool* pool, @@ -849,12 +855,12 @@ Status GetSparseTensorMessage(const SparseTensor& sparse_tensor, MemoryPool* poo Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) { // emulates the behavior of Write without actually writing - auto options = IpcOptions::Defaults(); + auto options = IpcWriteOptions::Defaults(); int32_t metadata_length = 0; int64_t body_length = 0; io::MockOutputStream dst; - RETURN_NOT_OK(WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length, options, - default_memory_pool())); + RETURN_NOT_OK( + WriteRecordBatch(batch, 0, &dst, &metadata_length, &body_length, options)); *size = dst.GetExtentBytesWritten(); return Status::OK(); } @@ -903,21 +909,14 @@ IpcPayloadWriter::~IpcPayloadWriter() {} Status IpcPayloadWriter::Start() { return Status::OK(); } -} // namespace internal - -namespace { - -/// A RecordBatchWriter implementation that writes to a IpcPayloadWriter. -class RecordBatchPayloadWriter : public RecordBatchWriter { +class ARROW_EXPORT IpcFormatWriter : public RecordBatchWriter { public: - ~RecordBatchPayloadWriter() override = default; - - RecordBatchPayloadWriter(std::unique_ptr payload_writer, - const Schema& schema, const IpcOptions& options, - DictionaryMemo* out_memo = nullptr) + /// A RecordBatchWriter implementation that writes to a IpcPayloadWriter. + IpcFormatWriter(std::unique_ptr payload_writer, + const Schema& schema, const IpcWriteOptions& options, + DictionaryMemo* out_memo = nullptr) : payload_writer_(std::move(payload_writer)), schema_(schema), - pool_(default_memory_pool()), dictionary_memo_(out_memo), options_(options) { if (out_memo == nullptr) { @@ -926,10 +925,10 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { } // A Schema-owning constructor variant - RecordBatchPayloadWriter(std::unique_ptr payload_writer, - const std::shared_ptr& schema, - const IpcOptions& options, DictionaryMemo* out_memo = nullptr) - : RecordBatchPayloadWriter(std::move(payload_writer), *schema, options, out_memo) { + IpcFormatWriter(std::unique_ptr payload_writer, + const std::shared_ptr& schema, const IpcWriteOptions& options, + DictionaryMemo* out_memo = nullptr) + : IpcFormatWriter(std::move(payload_writer), *schema, options, out_memo) { shared_schema_ = schema; } @@ -945,11 +944,11 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { wrote_dictionaries_ = true; } - // TODO(wesm): Check for delta dictionaries. Can we scan for - // deltas while computing the RecordBatch payload to save time? + // TODO: Check for delta dictionaries. Can we scan for deltas while computing + // the RecordBatch payload to save time? - internal::IpcPayload payload; - RETURN_NOT_OK(GetRecordBatchPayload(batch, options_, pool_, &payload)); + IpcPayload payload; + RETURN_NOT_OK(GetRecordBatchPayload(batch, options_, &payload)); return payload_writer_->WritePayload(payload); } @@ -958,8 +957,6 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { return payload_writer_->Close(); } - void set_memory_pool(MemoryPool* pool) override { pool_ = pool; } - Status Start() { started_ = true; RETURN_NOT_OK(payload_writer_->Start()); @@ -985,31 +982,25 @@ class RecordBatchPayloadWriter : public RecordBatchWriter { int64_t dictionary_id = pair.first; const auto& dictionary = pair.second; - RETURN_NOT_OK( - GetDictionaryPayload(dictionary_id, dictionary, options_, pool_, &payload)); + RETURN_NOT_OK(GetDictionaryPayload(dictionary_id, dictionary, options_, &payload)); RETURN_NOT_OK(payload_writer_->WritePayload(payload)); } return Status::OK(); } - protected: - std::unique_ptr payload_writer_; + std::unique_ptr payload_writer_; std::shared_ptr shared_schema_; const Schema& schema_; - MemoryPool* pool_; DictionaryMemo* dictionary_memo_; DictionaryMemo internal_dict_memo_; bool started_ = false; bool wrote_dictionaries_ = false; - IpcOptions options_; + IpcWriteOptions options_; }; -// ---------------------------------------------------------------------- -// Stream and file writer implementation - class StreamBookKeeper { public: - explicit StreamBookKeeper(const IpcOptions& options, io::OutputStream* sink) + explicit StreamBookKeeper(const IpcWriteOptions& options, io::OutputStream* sink) : options_(options), sink_(sink), position_(-1) {} Status UpdatePosition() { return sink_->Tell().Value(&position_); } @@ -1041,28 +1032,28 @@ class StreamBookKeeper { // End of stream marker constexpr int32_t kZeroLength = 0; if (!options_.write_legacy_ipc_format) { - RETURN_NOT_OK(Write(&internal::kIpcContinuationToken, sizeof(int32_t))); + RETURN_NOT_OK(Write(&kIpcContinuationToken, sizeof(int32_t))); } return Write(&kZeroLength, sizeof(int32_t)); } protected: - IpcOptions options_; + IpcWriteOptions options_; io::OutputStream* sink_; int64_t position_; }; /// A IpcPayloadWriter implementation that writes to an IPC stream /// (with an end-of-stream marker) -class PayloadStreamWriter : public internal::IpcPayloadWriter, - protected StreamBookKeeper { +class PayloadStreamWriter : public IpcPayloadWriter, protected StreamBookKeeper { public: - PayloadStreamWriter(const IpcOptions& options, io::OutputStream* sink) + PayloadStreamWriter(io::OutputStream* sink, + const IpcWriteOptions& options = IpcWriteOptions::Defaults()) : StreamBookKeeper(options, sink) {} ~PayloadStreamWriter() override = default; - Status WritePayload(const internal::IpcPayload& payload) override { + Status WritePayload(const IpcPayload& payload) override { #ifndef NDEBUG // Catch bug fixed in ARROW-3236 RETURN_NOT_OK(UpdatePositionCheckAligned()); @@ -1081,7 +1072,7 @@ class PayloadStreamWriter : public internal::IpcPayloadWriter, /// (with a footer as defined in File.fbs) class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBookKeeper { public: - PayloadFileWriter(const IpcOptions& options, const std::shared_ptr& schema, + PayloadFileWriter(const IpcWriteOptions& options, const std::shared_ptr& schema, io::OutputStream* sink) : StreamBookKeeper(options, sink), schema_(schema) {} @@ -1154,116 +1145,40 @@ class PayloadFileWriter : public internal::IpcPayloadWriter, protected StreamBoo std::vector record_batches_; }; -} // namespace - -class RecordBatchStreamWriter::RecordBatchStreamWriterImpl - : public RecordBatchPayloadWriter { - public: - RecordBatchStreamWriterImpl(io::OutputStream* sink, - const std::shared_ptr& schema, - const IpcOptions& options) - : RecordBatchPayloadWriter(std::unique_ptr( - new PayloadStreamWriter(options, sink)), - schema, options) {} - - ~RecordBatchStreamWriterImpl() = default; -}; - -class RecordBatchFileWriter::RecordBatchFileWriterImpl : public RecordBatchPayloadWriter { - public: - RecordBatchFileWriterImpl(io::OutputStream* sink, const std::shared_ptr& schema, - const IpcOptions& options) - : RecordBatchPayloadWriter(std::unique_ptr( - new PayloadFileWriter(options, schema, sink)), - schema, options) {} - - ~RecordBatchFileWriterImpl() = default; -}; - -RecordBatchStreamWriter::RecordBatchStreamWriter() {} - -RecordBatchStreamWriter::~RecordBatchStreamWriter() {} - -Status RecordBatchStreamWriter::WriteRecordBatch(const RecordBatch& batch) { - return impl_->WriteRecordBatch(batch); -} - -void RecordBatchStreamWriter::set_memory_pool(MemoryPool* pool) { - impl_->set_memory_pool(pool); -} - -Status RecordBatchStreamWriter::Open(io::OutputStream* sink, - const std::shared_ptr& schema, - std::shared_ptr* out) { - ASSIGN_OR_RAISE(*out, Open(sink, schema)); - return Status::OK(); -} +} // namespace internal -Result> RecordBatchStreamWriter::Open( +Result> NewStreamWriter( io::OutputStream* sink, const std::shared_ptr& schema, - const IpcOptions& options) { - // ctor is private - auto result = std::shared_ptr(new RecordBatchStreamWriter()); - result->impl_.reset(new RecordBatchStreamWriterImpl(sink, schema, options)); - return std::move(result); -} - -Result> RecordBatchStreamWriter::Open( - io::OutputStream* sink, const std::shared_ptr& schema) { - auto options = IpcOptions::Defaults(); - return Open(sink, schema, options); -} - -Status RecordBatchStreamWriter::Close() { return impl_->Close(); } - -RecordBatchFileWriter::RecordBatchFileWriter() {} - -RecordBatchFileWriter::~RecordBatchFileWriter() {} - -Status RecordBatchFileWriter::Open(io::OutputStream* sink, - const std::shared_ptr& schema, - std::shared_ptr* out) { - ASSIGN_OR_RAISE(*out, Open(sink, schema)); - return Status::OK(); + const IpcWriteOptions& options) { + return std::make_shared( + ::arrow::internal::make_unique(sink, options), + schema, options); } -Result> RecordBatchFileWriter::Open( +Result> NewFileWriter( io::OutputStream* sink, const std::shared_ptr& schema, - const IpcOptions& options) { - // ctor is private - auto result = std::shared_ptr(new RecordBatchFileWriter()); - result->file_impl_.reset(new RecordBatchFileWriterImpl(sink, schema, options)); - return std::move(result); + const IpcWriteOptions& options) { + return std::make_shared( + ::arrow::internal::make_unique(options, schema, sink), + schema, options); } -Result> RecordBatchFileWriter::Open( - io::OutputStream* sink, const std::shared_ptr& schema) { - auto options = IpcOptions::Defaults(); - return Open(sink, schema, options); -} - -Status RecordBatchFileWriter::WriteRecordBatch(const RecordBatch& batch) { - return file_impl_->WriteRecordBatch(batch); -} - -Status RecordBatchFileWriter::Close() { return file_impl_->Close(); } - namespace internal { Status OpenRecordBatchWriter(std::unique_ptr sink, const std::shared_ptr& schema, std::unique_ptr* out) { - auto options = IpcOptions::Defaults(); + auto options = IpcWriteOptions::Defaults(); ASSIGN_OR_RAISE(*out, OpenRecordBatchWriter(std::move(sink), schema, options)); return Status::OK(); } Result> OpenRecordBatchWriter( std::unique_ptr sink, const std::shared_ptr& schema, - const IpcOptions& options) { + const IpcWriteOptions& options) { // XXX should we call Start()? - return std::unique_ptr( - new RecordBatchPayloadWriter(std::move(sink), schema, options)); + return ::arrow::internal::make_unique(std::move(sink), + schema, options); } } // namespace internal @@ -1278,53 +1193,131 @@ Result> SerializeRecordBatch(const RecordBatch& batch, ARROW_ASSIGN_OR_RAISE(auto buffer, mm->AllocateBuffer(size)); ARROW_ASSIGN_OR_RAISE(auto writer, Buffer::GetWriter(buffer)); - MemoryPool* pool; + IpcWriteOptions options; // XXX Should we have a helper function for getting a MemoryPool // for any MemoryManager (not only CPU)? if (mm->is_cpu()) { - pool = checked_pointer_cast(mm)->pool(); - } else { - // Allocations will be ephemeral anyway - pool = default_memory_pool(); + options.memory_pool = checked_pointer_cast(mm)->pool(); } - RETURN_NOT_OK(SerializeRecordBatch(batch, pool, writer.get())); + RETURN_NOT_OK(SerializeRecordBatch(batch, options, writer.get())); RETURN_NOT_OK(writer->Close()); return buffer; } -Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, +Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options, std::shared_ptr* out) { int64_t size = 0; RETURN_NOT_OK(GetRecordBatchSize(batch, &size)); std::shared_ptr buffer; - RETURN_NOT_OK(AllocateBuffer(pool, size, &buffer)); + RETURN_NOT_OK(AllocateBuffer(options.memory_pool, size, &buffer)); io::FixedSizeBufferWriter stream(buffer); - RETURN_NOT_OK(SerializeRecordBatch(batch, pool, &stream)); + RETURN_NOT_OK(SerializeRecordBatch(batch, options, &stream)); *out = buffer; return Status::OK(); } -Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, +Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options, io::OutputStream* out) { - auto options = IpcOptions::Defaults(); int32_t metadata_length = 0; int64_t body_length = 0; - return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, options, pool); + return WriteRecordBatch(batch, 0, out, &metadata_length, &body_length, options); } Status SerializeSchema(const Schema& schema, DictionaryMemo* dictionary_memo, MemoryPool* pool, std::shared_ptr* out) { ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create(1024, pool)); - auto options = IpcOptions::Defaults(); - auto payload_writer = make_unique(options, stream.get()); - RecordBatchPayloadWriter writer(std::move(payload_writer), schema, options, - dictionary_memo); + auto options = IpcWriteOptions::Defaults(); + internal::IpcFormatWriter writer( + ::arrow::internal::make_unique(stream.get()), schema, + options, dictionary_memo); // Write schema and populate fields (but not dictionaries) in dictionary_memo RETURN_NOT_OK(writer.Start()); return stream->Finish().Value(out); } +// ---------------------------------------------------------------------- +// Deprecated functions + +Status RecordBatchStreamWriter::Open(io::OutputStream* sink, + const std::shared_ptr& schema, + std::shared_ptr* out) { + ASSIGN_OR_RAISE(*out, NewStreamWriter(sink, schema)); + return Status::OK(); +} + +Result> RecordBatchStreamWriter::Open( + io::OutputStream* sink, const std::shared_ptr& schema) { + return NewStreamWriter(sink, schema); +} + +Result> RecordBatchStreamWriter::Open( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options) { + return NewStreamWriter(sink, schema, options); +} + +Status RecordBatchFileWriter::Open(io::OutputStream* sink, + const std::shared_ptr& schema, + std::shared_ptr* out) { + ASSIGN_OR_RAISE(*out, NewFileWriter(sink, schema)); + return Status::OK(); +} + +Result> RecordBatchFileWriter::Open( + io::OutputStream* sink, const std::shared_ptr& schema) { + return NewFileWriter(sink, schema); +} + +Result> RecordBatchFileWriter::Open( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options) { + return NewFileWriter(sink, schema, options); +} + +Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, + std::shared_ptr* out) { + IpcWriteOptions options; + options.memory_pool = pool; + return SerializeRecordBatch(batch, options, out); +} + +Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, + io::OutputStream* out) { + IpcWriteOptions options; + options.memory_pool = pool; + return SerializeRecordBatch(batch, options, out); +} + +Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, + io::OutputStream* dst, int32_t* metadata_length, + int64_t* body_length, const IpcWriteOptions& options, + MemoryPool* pool) { + IpcWriteOptions modified_options = options; + modified_options.memory_pool = pool; + return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length, + modified_options); +} + +namespace internal { + +Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options, + MemoryPool* pool, IpcPayload* out) { + IpcWriteOptions modified_options = options; + modified_options.memory_pool = pool; + return GetRecordBatchPayload(batch, modified_options, out); +} + +Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, + const IpcWriteOptions& options, MemoryPool* pool, + IpcPayload* payload) { + IpcWriteOptions modified_options = options; + modified_options.memory_pool = pool; + return GetDictionaryPayload(id, dictionary, modified_options, payload); +} + +} // namespace internal + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index d0b5f4342cd..2d47a3772e3 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -26,13 +26,15 @@ #include "arrow/ipc/dictionary.h" // IWYU pragma: export #include "arrow/ipc/message.h" #include "arrow/ipc/options.h" -#include "arrow/type_fwd.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" #include "arrow/util/visibility.h" namespace arrow { class Array; class Buffer; +class MemoryManager; class MemoryPool; class RecordBatch; class Schema; @@ -77,104 +79,32 @@ class ARROW_EXPORT RecordBatchWriter { /// \return Status virtual Status Close() = 0; - /// In some cases, writing may require memory allocation. We use the default - /// memory pool, but provide the option to override - /// - /// \param pool the memory pool to use for required allocations - virtual void set_memory_pool(MemoryPool* pool) = 0; + ARROW_DEPRECATED("No-op. Pass MemoryPool using IpcWriteOptions") + void set_memory_pool(MemoryPool* pool) {} }; -/// \class RecordBatchStreamWriter -/// \brief Synchronous batch stream writer that writes the Arrow streaming -/// format -class ARROW_EXPORT RecordBatchStreamWriter : public RecordBatchWriter { - public: - ~RecordBatchStreamWriter() override; - - /// Create a new writer from stream sink and schema. User is responsible for - /// closing the actual OutputStream. - /// - /// \param[in] sink output stream to write to - /// \param[in] schema the schema of the record batches to be written - /// \param[out] out the created stream writer - /// \return Status - static Status Open(io::OutputStream* sink, const std::shared_ptr& schema, - std::shared_ptr* out); - - /// Create a new writer from stream sink and schema. User is responsible for - /// closing the actual OutputStream. - /// - /// \param[in] sink output stream to write to - /// \param[in] schema the schema of the record batches to be written - /// \return Result> - static Result> Open( - io::OutputStream* sink, const std::shared_ptr& schema); - static Result> Open( - io::OutputStream* sink, const std::shared_ptr& schema, - const IpcOptions& options); - - /// \brief Write a record batch to the stream - /// - /// \param[in] batch the record batch to write - /// \return Status - Status WriteRecordBatch(const RecordBatch& batch) override; - - /// \brief Close the stream by writing a 4-byte int32 0 EOS market - /// \return Status - Status Close() override; - - void set_memory_pool(MemoryPool* pool) override; - - protected: - RecordBatchStreamWriter(); - class ARROW_NO_EXPORT RecordBatchStreamWriterImpl; - std::unique_ptr impl_; -}; - -/// \brief Creates the Arrow record batch file format +/// Create a new IPC stream writer from stream sink and schema. User is +/// responsible for closing the actual OutputStream. /// -/// Implements the random access file format, which structurally is a record -/// batch stream followed by a metadata footer at the end of the file. Magic -/// numbers are written at the start and end of the file -class ARROW_EXPORT RecordBatchFileWriter : public RecordBatchStreamWriter { - public: - ~RecordBatchFileWriter() override; - - /// Create a new writer from stream sink and schema - /// - /// \param[in] sink output stream to write to - /// \param[in] schema the schema of the record batches to be written - /// \param[out] out the created stream writer - /// \return Status - static Status Open(io::OutputStream* sink, const std::shared_ptr& schema, - std::shared_ptr* out); - - /// Create a new writer from stream sink and schema - /// - /// \param[in] sink output stream to write to - /// \param[in] schema the schema of the record batches to be written - /// \return Status - static Result> Open( - io::OutputStream* sink, const std::shared_ptr& schema); - static Result> Open( - io::OutputStream* sink, const std::shared_ptr& schema, - const IpcOptions& options); - - /// \brief Write a record batch to the file - /// - /// \param[in] batch the record batch to write - /// \return Status - Status WriteRecordBatch(const RecordBatch& batch) override; - - /// \brief Close the file stream by writing the file footer and magic number - /// \return Status - Status Close() override; +/// \param[in] sink output stream to write to +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization +/// \return Result> +ARROW_EXPORT +Result> NewStreamWriter( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options = IpcWriteOptions::Defaults()); - private: - RecordBatchFileWriter(); - class ARROW_NO_EXPORT RecordBatchFileWriterImpl; - std::unique_ptr file_impl_; -}; +/// Create a new IPC file writer from stream sink and schema +/// +/// \param[in] sink output stream to write to +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization +/// \return Status +ARROW_EXPORT +Result> NewFileWriter( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options = IpcWriteOptions::Defaults()); /// \brief Low-level API for writing a record batch (without schema) /// to an OutputStream as encapsulated IPC message. See Arrow format @@ -188,22 +118,20 @@ class ARROW_EXPORT RecordBatchFileWriter : public RecordBatchStreamWriter { /// including padding to a 64-byte boundary /// \param[out] body_length the size of the contiguous buffer block plus /// \param[in] options options for serialization -/// \param[in] pool the memory pool to allocate memory from /// \return Status ARROW_EXPORT Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, - int64_t* body_length, const IpcOptions& options, - MemoryPool* pool); + int64_t* body_length, const IpcWriteOptions& options); /// \brief Serialize record batch as encapsulated IPC message in a new buffer /// /// \param[in] batch the record batch -/// \param[in] pool a MemoryPool to allocate memory from +/// \param[in] options the IpcWriteOptions to use for serialization /// \param[out] out the serialized message /// \return Status ARROW_EXPORT -Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, +Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options, std::shared_ptr* out); /// \brief Serialize record batch as encapsulated IPC message in a new buffer @@ -218,14 +146,14 @@ Result> SerializeRecordBatch(const RecordBatch& batch, /// \brief Write record batch to OutputStream /// /// \param[in] batch the record batch to write -/// \param[in] pool a MemoryPool to use for temporary allocations, if needed +/// \param[in] options the IpcWriteOptions to use for serialization /// \param[in] out the OutputStream to write the output to /// \return Status /// /// If writing to pre-allocated memory, you can use /// arrow::ipc::GetRecordBatchSize to compute how much space is required ARROW_EXPORT -Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, +Status SerializeRecordBatch(const RecordBatch& batch, const IpcWriteOptions& options, io::OutputStream* out); /// \brief Serialize schema as encapsulated IPC message @@ -246,7 +174,7 @@ Status SerializeSchema(const Schema& schema, DictionaryMemo* dictionary_memo, /// \return Status ARROW_EXPORT Status WriteRecordBatchStream(const std::vector>& batches, - const IpcOptions& options, io::OutputStream* dst); + const IpcWriteOptions& options, io::OutputStream* dst); /// \brief Compute the number of bytes needed to write a record batch including metadata /// @@ -370,7 +298,7 @@ Status OpenRecordBatchWriter(std::unique_ptr sink, ARROW_EXPORT Result> OpenRecordBatchWriter( std::unique_ptr sink, const std::shared_ptr& schema, - const IpcOptions& options); + const IpcWriteOptions& options); /// \brief Compute IpcPayload for the given schema /// \param[in] schema the Schema that is being serialized @@ -379,7 +307,7 @@ Result> OpenRecordBatchWriter( /// \param[out] out the returned vector of IpcPayloads /// \return Status ARROW_EXPORT -Status GetSchemaPayload(const Schema& schema, const IpcOptions& options, +Status GetSchemaPayload(const Schema& schema, const IpcWriteOptions& options, DictionaryMemo* dictionary_memo, IpcPayload* out); /// \brief Compute IpcPayload for a dictionary @@ -390,21 +318,19 @@ Status GetSchemaPayload(const Schema& schema, const IpcOptions& options, /// \return Status ARROW_EXPORT Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, - const IpcOptions& options, MemoryPool* pool, - IpcPayload* payload); + const IpcWriteOptions& options, IpcPayload* payload); /// \brief Compute IpcPayload for the given record batch /// \param[in] batch the RecordBatch that is being serialized /// \param[in] options options for serialization -/// \param[in,out] pool for any required temporary memory allocations /// \param[out] out the returned IpcPayload /// \return Status ARROW_EXPORT -Status GetRecordBatchPayload(const RecordBatch& batch, const IpcOptions& options, - MemoryPool* pool, IpcPayload* out); +Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options, + IpcPayload* out); ARROW_EXPORT -Status WriteIpcPayload(const IpcPayload& payload, const IpcOptions& options, +Status WriteIpcPayload(const IpcPayload& payload, const IpcWriteOptions& options, io::OutputStream* dst, int32_t* metadata_length); /// \brief Compute IpcPayload for the given sparse tensor @@ -418,5 +344,105 @@ Status GetSparseTensorPayload(const SparseTensor& sparse_tensor, MemoryPool* poo } // namespace internal +// Deprecated functions + +/// \class RecordBatchStreamWriter +/// \brief Synchronous batch stream writer that writes the Arrow streaming +/// format +class ARROW_EXPORT RecordBatchStreamWriter : public RecordBatchWriter { + public: + /// Create a new writer from stream sink and schema. User is responsible for + /// closing the actual OutputStream. + /// + /// \param[in] sink output stream to write to + /// \param[in] schema the schema of the record batches to be written + /// \param[out] out the created stream writer + /// \return Status + ARROW_DEPRECATED("Use arrow::ipc::NewStreamWriter()") + static Status Open(io::OutputStream* sink, const std::shared_ptr& schema, + std::shared_ptr* out); + + /// Create a new writer from stream sink and schema. User is responsible for + /// closing the actual OutputStream. + /// + /// \param[in] sink output stream to write to + /// \param[in] schema the schema of the record batches to be written + /// \return Result> + ARROW_DEPRECATED("Use arrow::ipc::NewStreamWriter()") + static Result> Open( + io::OutputStream* sink, const std::shared_ptr& schema); + + ARROW_DEPRECATED("Use arrow::ipc::NewStreamWriter()") + static Result> Open( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options); +}; + +/// \brief Creates the Arrow record batch file format +/// +/// Implements the random access file format, which structurally is a record +/// batch stream followed by a metadata footer at the end of the file. Magic +/// numbers are written at the start and end of the file +class ARROW_EXPORT RecordBatchFileWriter : public RecordBatchStreamWriter { + public: + /// Create a new writer from stream sink and schema + /// + /// \param[in] sink output stream to write to + /// \param[in] schema the schema of the record batches to be written + /// \param[out] out the created stream writer + /// \return Status + ARROW_DEPRECATED("Use arrow::ipc::NewFileWriter") + static Status Open(io::OutputStream* sink, const std::shared_ptr& schema, + std::shared_ptr* out); + + /// Create a new writer from stream sink and schema + /// + /// \param[in] sink output stream to write to + /// \param[in] schema the schema of the record batches to be written + /// \return Result> + ARROW_DEPRECATED("Use arrow::ipc::NewFileWriter") + static Result> Open( + io::OutputStream* sink, const std::shared_ptr& schema); + + ARROW_DEPRECATED("Use arrow::ipc::NewFileWriter") + static Result> Open( + io::OutputStream* sink, const std::shared_ptr& schema, + const IpcWriteOptions& options); +}; + +ARROW_DEPRECATED( + "Use version without MemoryPool argument " + "(use IpcWriteOptions to pass MemoryPool") +ARROW_EXPORT +Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, + io::OutputStream* dst, int32_t* metadata_length, + int64_t* body_length, const IpcWriteOptions& options, + MemoryPool* pool); + +ARROW_DEPRECATED("Use version with IpcWriteOptions") +ARROW_EXPORT +Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, + std::shared_ptr* out); + +ARROW_DEPRECATED("Use version with IpcWriteOptions") +ARROW_EXPORT +Status SerializeRecordBatch(const RecordBatch& batch, MemoryPool* pool, + io::OutputStream* out); + +namespace internal { + +ARROW_DEPRECATED("Pass MemoryPool with IpcWriteOptions") +ARROW_EXPORT +Status GetDictionaryPayload(int64_t id, const std::shared_ptr& dictionary, + const IpcWriteOptions& options, MemoryPool* pool, + IpcPayload* payload); + +ARROW_DEPRECATED("Pass MemoryPool with IpcWriteOptions") +ARROW_EXPORT +Status GetRecordBatchPayload(const RecordBatch& batch, const IpcWriteOptions& options, + MemoryPool* pool, IpcPayload* out); + +} // namespace internal + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/python/deserialize.cc b/cpp/src/arrow/python/deserialize.cc index 5218ffcd3be..28a739ec097 100644 --- a/cpp/src/arrow/python/deserialize.cc +++ b/cpp/src/arrow/python/deserialize.cc @@ -31,8 +31,10 @@ #include "arrow/array.h" #include "arrow/io/interfaces.h" #include "arrow/io/memory.h" +#include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/util.h" +#include "arrow/ipc/writer.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" @@ -323,7 +325,7 @@ Status ReadSerializedObject(io::RandomAccessFile* src, SerializedPyObject* out) // Align stream to 8-byte offset RETURN_NOT_OK(ipc::AlignStream(src, ipc::kArrowIpcAlignment)); std::shared_ptr reader; - RETURN_NOT_OK(ipc::RecordBatchStreamReader::Open(src, &reader)); + ARROW_ASSIGN_OR_RAISE(reader, ipc::RecordBatchStreamReader::Open(src)); RETURN_NOT_OK(reader->ReadNext(&out->batch)); /// Skip EOS marker @@ -403,7 +405,7 @@ Status GetSerializedFromComponents(int num_tensors, gil.release(); io::BufferReader buf_reader(data_buffer); std::shared_ptr reader; - RETURN_NOT_OK(ipc::RecordBatchStreamReader::Open(&buf_reader, &reader)); + ARROW_ASSIGN_OR_RAISE(reader, ipc::RecordBatchStreamReader::Open(&buf_reader)); RETURN_NOT_OK(reader->ReadNext(&out->batch)); gil.acquire(); } diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc index 2405ad934d2..3b823082859 100644 --- a/cpp/src/arrow/python/flight.cc +++ b/cpp/src/arrow/python/flight.cc @@ -220,7 +220,7 @@ Status PyFlightDataStream::Next(FlightPayload* payload) { return stream_->Next(p PyGeneratorFlightDataStream::PyGeneratorFlightDataStream( PyObject* generator, std::shared_ptr schema, PyGeneratorFlightDataStreamCallback callback) - : schema_(schema), options_(ipc::IpcOptions::Defaults()), callback_(callback) { + : schema_(schema), options_(ipc::IpcWriteOptions::Defaults()), callback_(callback) { Py_INCREF(generator); generator_.reset(generator); } diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h index 3aabadc78bb..2ecc7fadc71 100644 --- a/cpp/src/arrow/python/flight.h +++ b/cpp/src/arrow/python/flight.h @@ -322,7 +322,7 @@ class ARROW_PYFLIGHT_EXPORT PyGeneratorFlightDataStream OwnedRefNoGIL generator_; std::shared_ptr schema_; ipc::DictionaryMemo dictionary_memo_; - ipc::IpcOptions options_; + ipc::IpcWriteOptions options_; PyGeneratorFlightDataStreamCallback callback_; }; diff --git a/cpp/src/arrow/python/serialize.cc b/cpp/src/arrow/python/serialize.cc index 22a0449cbc9..ec38627c075 100644 --- a/cpp/src/arrow/python/serialize.cc +++ b/cpp/src/arrow/python/serialize.cc @@ -604,7 +604,8 @@ Status WriteNdarrayHeader(std::shared_ptr dtype, return serialized_tensor.WriteTo(dst); } -SerializedPyObject::SerializedPyObject() : ipc_options(ipc::IpcOptions::Defaults()) {} +SerializedPyObject::SerializedPyObject() + : ipc_options(ipc::IpcWriteOptions::Defaults()) {} Status SerializedPyObject::WriteTo(io::OutputStream* dst) { int32_t num_tensors = static_cast(this->tensors.size()); diff --git a/cpp/src/arrow/python/serialize.h b/cpp/src/arrow/python/serialize.h index 9fdb7d93ce1..8ef5516e059 100644 --- a/cpp/src/arrow/python/serialize.h +++ b/cpp/src/arrow/python/serialize.h @@ -54,7 +54,7 @@ struct ARROW_PYTHON_EXPORT SerializedPyObject { std::vector> sparse_tensors; std::vector> ndarrays; std::vector> buffers; - ipc::IpcOptions ipc_options; + ipc::IpcWriteOptions ipc_options; SerializedPyObject(); diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index b1a1708ba6c..6c1f52c526d 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -199,11 +199,17 @@ const std::string& RecordBatch::column_name(int i) const { return schema_->field(i)->name(); } -bool RecordBatch::Equals(const RecordBatch& other) const { +bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata) const { if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) { return false; } + if (check_metadata) { + if (!schema_->Equals(*other.schema(), /*check_metadata=*/true)) { + return false; + } + } + for (int i = 0; i < num_columns(); ++i) { if (!column(i)->Equals(other.column(i))) { return false; diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index ada1ad7eef4..1d5adb8495a 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -76,8 +76,11 @@ class ARROW_EXPORT RecordBatch { std::shared_ptr* out); /// \brief Determine if two record batches are exactly equal + /// + /// \param[in] other the RecordBatch to compare with + /// \param[in] check_metadata if true, check that Schema metadata is the same /// \return true if batches are equal - bool Equals(const RecordBatch& other) const; + bool Equals(const RecordBatch& other, bool check_metadata = false) const; /// \brief Determine if two record batches are approximately equal bool ApproxEquals(const RecordBatch& other) const; diff --git a/cpp/src/arrow/table_test.cc b/cpp/src/arrow/table_test.cc index 4e942eaab23..fb5ffa2b2b6 100644 --- a/cpp/src/arrow/table_test.cc +++ b/cpp/src/arrow/table_test.cc @@ -794,21 +794,29 @@ TEST_F(TestRecordBatch, Equals) { auto f1 = field("f1", uint8()); auto f2 = field("f2", int16()); + auto metadata = key_value_metadata({"foo"}, {"bar"}); + std::vector> fields = {f0, f1, f2}; auto schema = ::arrow::schema({f0, f1, f2}); auto schema2 = ::arrow::schema({f0, f1}); + auto schema3 = ::arrow::schema({f0, f1, f2}, metadata); auto a0 = MakeRandomArray(length); auto a1 = MakeRandomArray(length); auto a2 = MakeRandomArray(length); auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2}); + auto b2 = RecordBatch::Make(schema3, length, {a0, a1, a2}); auto b3 = RecordBatch::Make(schema2, length, {a0, a1}); auto b4 = RecordBatch::Make(schema, length, {a0, a1, a1}); ASSERT_TRUE(b1->Equals(*b1)); ASSERT_FALSE(b1->Equals(*b3)); ASSERT_FALSE(b1->Equals(*b4)); + + // Different metadata + ASSERT_TRUE(b1->Equals(*b2)); + ASSERT_FALSE(b1->Equals(*b2, /*check_metadata=*/true)); } TEST_F(TestRecordBatch, Validate) { diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index d0d864b3e74..a256f04d0d6 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -50,9 +50,9 @@ namespace arrow { -template -void AssertTsEqual(const T& expected, const T& actual) { - if (!expected.Equals(actual)) { +template +void AssertTsEqual(const T& expected, const T& actual, ExtraArgs... args) { + if (!expected.Equals(actual, args...)) { std::stringstream pp_expected; std::stringstream pp_actual; ::arrow::PrettyPrintOptions options(/*indent=*/2); @@ -78,7 +78,8 @@ void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose) } } -void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual) { +void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual, + bool check_metadata) { AssertTsEqual(expected, actual); } diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 850a52ef814..30310994ab3 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -172,7 +172,8 @@ using Datum = compute::Datum; ARROW_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose = false); ARROW_EXPORT void AssertBatchesEqual(const RecordBatch& expected, - const RecordBatch& actual); + const RecordBatch& actual, + bool check_metadata = false); ARROW_EXPORT void AssertChunkedEqual(const ChunkedArray& expected, const ChunkedArray& actual); ARROW_EXPORT void AssertChunkedEqual(const ChunkedArray& actual, diff --git a/cpp/src/arrow/util/compression.cc b/cpp/src/arrow/util/compression.cc index aa73bb7e8e1..8f0ac3fedb9 100644 --- a/cpp/src/arrow/util/compression.cc +++ b/cpp/src/arrow/util/compression.cc @@ -85,6 +85,28 @@ std::string Codec::GetCodecAsString(Compression::type t) { } } +Result Codec::GetCompressionType(const std::string& name) { + if (name == "UNCOMPRESSED") { + return Compression::UNCOMPRESSED; + } else if (name == "GZIP") { + return Compression::GZIP; + } else if (name == "SNAPPY") { + return Compression::SNAPPY; + } else if (name == "LZO") { + return Compression::LZO; + } else if (name == "BROTLI") { + return Compression::BROTLI; + } else if (name == "LZ4") { + return Compression::LZ4; + } else if (name == "ZSTD") { + return Compression::ZSTD; + } else if (name == "BZ2") { + return Compression::BZ2; + } else { + return Status::Invalid("Unrecognized compression type: ", name); + } +} + Result> Codec::Create(Compression::type codec_type, int compression_level) { std::unique_ptr codec; diff --git a/cpp/src/arrow/util/compression.h b/cpp/src/arrow/util/compression.h index ec8ce8ec3f5..b6ffc7762d6 100644 --- a/cpp/src/arrow/util/compression.h +++ b/cpp/src/arrow/util/compression.h @@ -31,11 +31,13 @@ namespace arrow { struct Compression { /// \brief Compression algorithm enum type { UNCOMPRESSED, SNAPPY, GZIP, BROTLI, ZSTD, LZ4, LZO, BZ2 }; + + static constexpr int kUseDefaultCompressionLevel = std::numeric_limits::min(); }; namespace util { -constexpr int kUseDefaultCompressionLevel = std::numeric_limits::min(); +constexpr int kUseDefaultCompressionLevel = Compression::kUseDefaultCompressionLevel; /// \brief Streaming compressor interface /// @@ -124,6 +126,9 @@ class ARROW_EXPORT Codec { /// \brief Return a string name for compression type static std::string GetCodecAsString(Compression::type t); + /// \brief Return compression type for name (all upper case) + static Result GetCompressionType(const std::string& name); + /// \brief Create a codec for the given compression algorithm static Result> Create( Compression::type codec, int compression_level = kUseDefaultCompressionLevel); diff --git a/cpp/src/arrow/util/compression_test.cc b/cpp/src/arrow/util/compression_test.cc index 0575fc26ee0..761d3c93e59 100644 --- a/cpp/src/arrow/util/compression_test.cc +++ b/cpp/src/arrow/util/compression_test.cc @@ -324,6 +324,21 @@ TEST(TestCodecMisc, GetCodecAsString) { ASSERT_EQ("BROTLI", Codec::GetCodecAsString(Compression::BROTLI)); ASSERT_EQ("LZ4", Codec::GetCodecAsString(Compression::LZ4)); ASSERT_EQ("ZSTD", Codec::GetCodecAsString(Compression::ZSTD)); + ASSERT_EQ("BZ2", Codec::GetCodecAsString(Compression::BZ2)); +} + +TEST(TestCodecMisc, GetCompressionType) { + ASSERT_OK_AND_EQ(Compression::UNCOMPRESSED, Codec::GetCompressionType("UNCOMPRESSED")); + ASSERT_OK_AND_EQ(Compression::SNAPPY, Codec::GetCompressionType("SNAPPY")); + ASSERT_OK_AND_EQ(Compression::GZIP, Codec::GetCompressionType("GZIP")); + ASSERT_OK_AND_EQ(Compression::LZO, Codec::GetCompressionType("LZO")); + ASSERT_OK_AND_EQ(Compression::BROTLI, Codec::GetCompressionType("BROTLI")); + ASSERT_OK_AND_EQ(Compression::LZ4, Codec::GetCompressionType("LZ4")); + ASSERT_OK_AND_EQ(Compression::ZSTD, Codec::GetCompressionType("ZSTD")); + ASSERT_OK_AND_EQ(Compression::BZ2, Codec::GetCompressionType("BZ2")); + + ASSERT_RAISES(Invalid, Codec::GetCompressionType("unk")); + ASSERT_RAISES(Invalid, Codec::GetCompressionType("snappy")); } TEST_P(CodecTest, CodecRoundtrip) { diff --git a/cpp/src/parquet/arrow/reader_internal.cc b/cpp/src/parquet/arrow/reader_internal.cc index fd1ec1fec57..dbbe0b066d7 100644 --- a/cpp/src/parquet/arrow/reader_internal.cc +++ b/cpp/src/parquet/arrow/reader_internal.cc @@ -35,6 +35,7 @@ #include "arrow/extension_type.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/status.h" #include "arrow/table.h" #include "arrow/type.h" @@ -592,7 +593,7 @@ Status GetOriginSchema(const std::shared_ptr& metadata, ::arrow::ipc::DictionaryMemo dict_memo; ::arrow::io::BufferReader input(schema_buf); - RETURN_NOT_OK(::arrow::ipc::ReadSchema(&input, &dict_memo, out)); + ARROW_ASSIGN_OR_RAISE(*out, ::arrow::ipc::ReadSchema(&input, &dict_memo)); if (metadata->size() > 1) { // Copy the metadata without the schema key diff --git a/dev/archery/archery/integration/util.py b/dev/archery/archery/integration/util.py index 75f6e9ac4f0..e3f2542b1e8 100644 --- a/dev/archery/archery/integration/util.py +++ b/dev/archery/archery/integration/util.py @@ -131,7 +131,7 @@ def run_cmd(cmd): except subprocess.CalledProcessError as e: # this avoids hiding the stdout / stderr of failed processes sio = io.StringIO() - print('Command failed:', cmd, file=sio) + print('Command failed:', " ".join(cmd), file=sio) print('With output:', file=sio) print('--------------', file=sio) print(frombytes(e.output), file=sio) diff --git a/docker-compose.yml b/docker-compose.yml index 6f87d2b5bb5..229812628d1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1028,13 +1028,14 @@ services: <<: *ccache # tell archery where the arrow binaries are located ARROW_CPP_EXE_PATH: /build/cpp/debug + # Running integration tests serially until ARROW-8176 resolved command: ["/arrow/ci/scripts/cpp_build.sh /arrow /build && /arrow/ci/scripts/go_build.sh /arrow && /arrow/ci/scripts/java_build.sh /arrow /build && /arrow/ci/scripts/js_build.sh /arrow /build && pip install -e /arrow/dev/archery && - archery integration --with-all --run-flight"] + archery integration --with-all --serial --run-flight"] ################################ Docs ####################################### diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 12a5436d85d..b29e539b7bb 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -1418,7 +1418,7 @@ cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *: cdef: unique_ptr[CFlightDataStream] data_stream # TODO make it possible to pass IPC options around? - cdef CIpcOptions c_ipc_options = CIpcOptions.Defaults() + cdef CIpcWriteOptions c_ipc_options = CIpcWriteOptions.Defaults() py_stream = self if not isinstance(py_stream, GeneratorStream): @@ -1480,7 +1480,6 @@ cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *: check_flight_status(GetRecordBatchPayload( deref(batch.batch), c_ipc_options, - c_default_memory_pool(), &payload.ipc_message)) if metadata: payload.app_metadata = pyarrow_unwrap_buffer(as_buffer(metadata)) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index b3c1b17d89e..4d611f6c48e 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1176,14 +1176,23 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: MessageType_V3" arrow::ipc::MetadataVersion::V3" MessageType_V4" arrow::ipc::MetadataVersion::V4" - cdef cppclass CIpcOptions" arrow::ipc::IpcOptions": + cdef cppclass CIpcWriteOptions" arrow::ipc::IpcWriteOptions": c_bool allow_64bit int max_recursion_depth int32_t alignment c_bool write_legacy_ipc_format + CMemoryPool* memory_pool @staticmethod - CIpcOptions Defaults() + CIpcWriteOptions Defaults() + + cdef cppclass CIpcReadOptions" arrow::ipc::IpcReadOptions": + int max_recursion_depth + CMemoryPool* memory_pool + shared_ptr[unordered_set[int]] included_fields + + @staticmethod + CIpcReadOptions Defaults() cdef cppclass CDictionaryMemo" arrow::ipc::DictionaryMemo": pass @@ -1207,7 +1216,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: MetadataVersion metadata_version() MessageType type() - CStatus SerializeTo(COutputStream* stream, const CIpcOptions& options, + CStatus SerializeTo(COutputStream* stream, + const CIpcWriteOptions& options, int64_t* output_length) c_string FormatMessageType(MessageType type) @@ -1226,37 +1236,33 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: cdef cppclass CRecordBatchStreamReader \ " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader): @staticmethod - CStatus Open(const CInputStream* stream, - shared_ptr[CRecordBatchReader]* out) + CResult[shared_ptr[CRecordBatchReader]] Open( + const CInputStream* stream, const CIpcReadOptions& options) @staticmethod - CStatus Open2" Open"(unique_ptr[CMessageReader] message_reader, - shared_ptr[CRecordBatchReader]* out) + CResult[shared_ptr[CRecordBatchReader]] Open2" Open"( + unique_ptr[CMessageReader] message_reader, + const CIpcReadOptions& options) - cdef cppclass CRecordBatchStreamWriter \ - " arrow::ipc::RecordBatchStreamWriter"(CRecordBatchWriter): - @staticmethod - CResult[shared_ptr[CRecordBatchWriter]] Open( - COutputStream* sink, const shared_ptr[CSchema]& schema, - CIpcOptions& options) + CResult[shared_ptr[CRecordBatchWriter]] NewStreamWriter( + COutputStream* sink, const shared_ptr[CSchema]& schema, + CIpcWriteOptions& options) - cdef cppclass CRecordBatchFileWriter \ - " arrow::ipc::RecordBatchFileWriter"(CRecordBatchWriter): - @staticmethod - CResult[shared_ptr[CRecordBatchWriter]] Open( - COutputStream* sink, const shared_ptr[CSchema]& schema, - CIpcOptions& options) + CResult[shared_ptr[CRecordBatchWriter]] NewFileWriter( + COutputStream* sink, const shared_ptr[CSchema]& schema, + CIpcWriteOptions& options) cdef cppclass CRecordBatchFileReader \ " arrow::ipc::RecordBatchFileReader": @staticmethod - CStatus Open(CRandomAccessFile* file, - shared_ptr[CRecordBatchFileReader]* out) + CResult[shared_ptr[CRecordBatchFileReader]] Open( + CRandomAccessFile* file, + const CIpcReadOptions& options) @staticmethod - CStatus Open2" Open"(CRandomAccessFile* file, - int64_t footer_offset, - shared_ptr[CRecordBatchFileReader]* out) + CResult[shared_ptr[CRecordBatchFileReader]] Open2" Open"( + CRandomAccessFile* file, int64_t footer_offset, + const CIpcReadOptions& options) shared_ptr[CSchema] schema() @@ -1276,26 +1282,27 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: CResult[shared_ptr[CTensor]] ReadTensor(CInputStream* stream) - CStatus ReadRecordBatch(const CMessage& message, - const shared_ptr[CSchema]& schema, - CDictionaryMemo* dictionary_memo, - shared_ptr[CRecordBatch]* out) + CResult[shared_ptr[CRecordBatch]] ReadRecordBatch( + const CMessage& message, const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, + const CIpcReadOptions& options) CStatus SerializeSchema(const CSchema& schema, CDictionaryMemo* dictionary_memo, CMemoryPool* pool, shared_ptr[CBuffer]* out) CStatus SerializeRecordBatch(const CRecordBatch& schema, - CMemoryPool* pool, + const CIpcWriteOptions& options, shared_ptr[CBuffer]* out) - CStatus ReadSchema(CInputStream* stream, CDictionaryMemo* dictionary_memo, - shared_ptr[CSchema]* out) + CResult[shared_ptr[CSchema]] ReadSchema(CInputStream* stream, + CDictionaryMemo* dictionary_memo) - CStatus ReadRecordBatch(const shared_ptr[CSchema]& schema, - CDictionaryMemo* dictionary_memo, - CInputStream* stream, - shared_ptr[CRecordBatch]* out) + CResult[shared_ptr[CRecordBatch]] ReadRecordBatch( + const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, + const CIpcReadOptions& options, + CInputStream* stream) CStatus AlignStream(CInputStream* stream, int64_t alignment) CStatus AlignStream(COutputStream* stream, int64_t alignment) @@ -1303,8 +1310,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: cdef CStatus GetRecordBatchPayload\ " arrow::ipc::internal::GetRecordBatchPayload"( const CRecordBatch& batch, - const CIpcOptions& options, - CMemoryPool* pool, + const CIpcWriteOptions& options, CIpcPayload* out) cdef cppclass CFeatherWriter" arrow::ipc::feather::TableWriter": diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index bfe3e17e354..585bcbfa924 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -78,7 +78,7 @@ cdef class Message: cdef: int64_t output_length = 0 COutputStream* out - CIpcOptions options + CIpcWriteOptions options options.alignment = alignment out = sink.get_output_stream().get() @@ -254,7 +254,7 @@ cdef class _CRecordBatchWriter: cdef class _RecordBatchStreamWriter(_CRecordBatchWriter): cdef: shared_ptr[COutputStream] sink - CIpcOptions options + CIpcWriteOptions options bint closed def __cinit__(self): @@ -272,8 +272,8 @@ cdef class _RecordBatchStreamWriter(_CRecordBatchWriter): get_writer(sink, &self.sink) with nogil: self.writer = GetResultValue( - CRecordBatchStreamWriter.Open( - self.sink.get(), schema.sp_schema, self.options)) + NewStreamWriter(self.sink.get(), schema.sp_schema, + self.options)) cdef _get_input_stream(object source, shared_ptr[CInputStream]* out): @@ -339,6 +339,7 @@ cdef class _CRecordBatchReader: cdef class _RecordBatchStreamReader(_CRecordBatchReader): cdef: shared_ptr[CInputStream] in_stream + CIpcReadOptions options cdef readonly: Schema schema @@ -349,8 +350,8 @@ cdef class _RecordBatchStreamReader(_CRecordBatchReader): def _open(self, source): _get_input_stream(source, &self.in_stream) with nogil: - check_status(CRecordBatchStreamReader.Open( - self.in_stream.get(), &self.reader)) + self.reader = GetResultValue(CRecordBatchStreamReader.Open( + self.in_stream.get(), self.options)) self.schema = pyarrow_wrap_schema(self.reader.get().schema()) @@ -362,14 +363,14 @@ cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): get_writer(sink, &self.sink) with nogil: self.writer = GetResultValue( - CRecordBatchFileWriter.Open( - self.sink.get(), schema.sp_schema, self.options)) + NewFileWriter(self.sink.get(), schema.sp_schema, self.options)) cdef class _RecordBatchFileReader: cdef: shared_ptr[CRecordBatchFileReader] reader shared_ptr[CRandomAccessFile] file + CIpcReadOptions options cdef readonly: Schema schema @@ -391,12 +392,14 @@ cdef class _RecordBatchFileReader: with nogil: if offset != 0: - check_status( + self.reader = GetResultValue( CRecordBatchFileReader.Open2(self.file.get(), offset, - &self.reader)) + self.options)) + else: - check_status( - CRecordBatchFileReader.Open(self.file.get(), &self.reader)) + self.reader = GetResultValue( + CRecordBatchFileReader.Open(self.file.get(), + self.options)) self.schema = pyarrow_wrap_schema(self.reader.get().schema()) @@ -594,7 +597,7 @@ def read_schema(obj, DictionaryMemo dictionary_memo=None): arg_dict_memo = &temp_memo with nogil: - check_status(ReadSchema(cpp_file.get(), arg_dict_memo, &result)) + result = GetResultValue(ReadSchema(cpp_file.get(), arg_dict_memo)) return pyarrow_wrap_schema(result) @@ -634,8 +637,10 @@ def read_record_batch(obj, Schema schema, arg_dict_memo = &temp_memo with nogil: - check_status(ReadRecordBatch(deref(message.message.get()), - schema.sp_schema, - arg_dict_memo, &result)) + result = GetResultValue( + ReadRecordBatch(deref(message.message.get()), + schema.sp_schema, + arg_dict_memo, + CIpcReadOptions.Defaults())) return pyarrow_wrap_batch(result) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index b2f3293602b..c53216e5331 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -681,13 +681,13 @@ cdef class RecordBatch(_PandasConvertible): ------- serialized : Buffer """ - cdef: - shared_ptr[CBuffer] buffer - CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + cdef shared_ptr[CBuffer] buffer + cdef CIpcWriteOptions options = CIpcWriteOptions.Defaults() + options.memory_pool = maybe_unbox_memory_pool(memory_pool) with nogil: check_status(SerializeRecordBatch(deref(self.batch), - pool, &buffer)) + options, &buffer)) return pyarrow_wrap_buffer(buffer) def slice(self, offset=0, length=None): diff --git a/r/src/message.cpp b/r/src/message.cpp index 1726eb28b81..440bb5b58b9 100644 --- a/r/src/message.cpp +++ b/r/src/message.cpp @@ -57,31 +57,25 @@ bool ipc___Message__Equals(const std::unique_ptr& x, std::shared_ptr ipc___ReadRecordBatch__Message__Schema( const std::unique_ptr& message, const std::shared_ptr& schema) { - std::shared_ptr batch; - // TODO: perhaps this should come from the R side arrow::ipc::DictionaryMemo memo; - STOP_IF_NOT_OK(arrow::ipc::ReadRecordBatch(*message, schema, &memo, &batch)); - return batch; + return VALUE_OR_STOP(arrow::ipc::ReadRecordBatch( + *message, schema, &memo, arrow::ipc::IpcReadOptions::Defaults())); } // [[arrow::export]] std::shared_ptr ipc___ReadSchema_InputStream( const std::shared_ptr& stream) { - std::shared_ptr schema; // TODO: promote to function argument arrow::ipc::DictionaryMemo memo; - STOP_IF_NOT_OK(arrow::ipc::ReadSchema(stream.get(), &memo, &schema)); - return schema; + return VALUE_OR_STOP(arrow::ipc::ReadSchema(stream.get(), &memo)); } // [[arrow::export]] std::shared_ptr ipc___ReadSchema_Message( const std::unique_ptr& message) { - std::shared_ptr schema; arrow::ipc::DictionaryMemo empty_memo; - STOP_IF_NOT_OK(arrow::ipc::ReadSchema(*message, &empty_memo, &schema)); - return schema; + return VALUE_OR_STOP(arrow::ipc::ReadSchema(*message, &empty_memo)); } //--------- MessageReader diff --git a/r/src/recordbatch.cpp b/r/src/recordbatch.cpp index 2babc642179..374570f95d1 100644 --- a/r/src/recordbatch.cpp +++ b/r/src/recordbatch.cpp @@ -162,8 +162,8 @@ Rcpp::RawVector ipc___SerializeRecordBatch__Raw( // serialize into the bytes of the raw vector auto buffer = std::make_shared>(out); arrow::io::FixedSizeBufferWriter stream(buffer); - STOP_IF_NOT_OK( - arrow::ipc::SerializeRecordBatch(*batch, arrow::default_memory_pool(), &stream)); + STOP_IF_NOT_OK(arrow::ipc::SerializeRecordBatch( + *batch, arrow::ipc::IpcWriteOptions::Defaults(), &stream)); STOP_IF_NOT_OK(stream.Close()); return out; @@ -173,11 +173,10 @@ Rcpp::RawVector ipc___SerializeRecordBatch__Raw( std::shared_ptr ipc___ReadRecordBatch__InputStream__Schema( const std::shared_ptr& stream, const std::shared_ptr& schema) { - std::shared_ptr batch; // TODO: promote to function arg arrow::ipc::DictionaryMemo memo; - STOP_IF_NOT_OK(arrow::ipc::ReadRecordBatch(schema, &memo, stream.get(), &batch)); - return batch; + return VALUE_OR_STOP(arrow::ipc::ReadRecordBatch( + schema, &memo, arrow::ipc::IpcReadOptions::Defaults(), stream.get())); } namespace arrow { diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp index 42e2968fbc8..1a161ba779b 100644 --- a/r/src/recordbatchreader.cpp +++ b/r/src/recordbatchreader.cpp @@ -39,8 +39,7 @@ std::shared_ptr RecordBatchReader__ReadNext( std::shared_ptr ipc___RecordBatchStreamReader__Open( const std::shared_ptr& stream) { std::shared_ptr reader; - STOP_IF_NOT_OK(arrow::ipc::RecordBatchStreamReader::Open(stream, &reader)); - return reader; + return VALUE_OR_STOP(arrow::ipc::RecordBatchStreamReader::Open(stream)); } // [[arrow::export]] @@ -87,8 +86,7 @@ std::shared_ptr ipc___RecordBatchFileReader__ReadRecordBatch std::shared_ptr ipc___RecordBatchFileReader__Open( const std::shared_ptr& file) { std::shared_ptr reader; - STOP_IF_NOT_OK(arrow::ipc::RecordBatchFileReader::Open(file, &reader)); - return reader; + return VALUE_OR_STOP(arrow::ipc::RecordBatchFileReader::Open(file)); } // [[arrow::export]] diff --git a/r/src/recordbatchwriter.cpp b/r/src/recordbatchwriter.cpp index 751ce1f193b..c6ff171f80b 100644 --- a/r/src/recordbatchwriter.cpp +++ b/r/src/recordbatchwriter.cpp @@ -43,20 +43,18 @@ void ipc___RecordBatchWriter__Close( std::shared_ptr ipc___RecordBatchFileWriter__Open( const std::shared_ptr& stream, const std::shared_ptr& schema, bool use_legacy_format) { - auto options = arrow::ipc::IpcOptions::Defaults(); + auto options = arrow::ipc::IpcWriteOptions::Defaults(); options.write_legacy_ipc_format = use_legacy_format; - return VALUE_OR_STOP( - arrow::ipc::RecordBatchFileWriter::Open(stream.get(), schema, options)); + return VALUE_OR_STOP(arrow::ipc::NewFileWriter(stream.get(), schema, options)); } // [[arrow::export]] std::shared_ptr ipc___RecordBatchStreamWriter__Open( const std::shared_ptr& stream, const std::shared_ptr& schema, bool use_legacy_format) { - auto options = arrow::ipc::IpcOptions::Defaults(); + auto options = arrow::ipc::IpcWriteOptions::Defaults(); options.write_legacy_ipc_format = use_legacy_format; - return VALUE_OR_STOP( - arrow::ipc::RecordBatchStreamWriter::Open(stream.get(), schema, options)); + return VALUE_OR_STOP(NewStreamWriter(stream.get(), schema, options)); } #endif diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index dfca492f70d..d3426492737 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -174,6 +174,9 @@ test_that("IPC/Arrow format data", { dim(ds), "Number of rows unknown; returning NA" ) + # This causes a segfault on Windows R 32-bit following ARROW-7979 + # TODO: fix me + skip_on_os("windows") expect_equivalent( ds %>% select(string = chr, integer = int, part) %>%