diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0888a8b97fa..b77f8c79fa0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -102,7 +102,9 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") ON) endif() -if(NOT ARROW_BUILD_TESTS) +if(ARROW_BUILD_TESTS) + set(ARROW_BUILD_STATIC ON) +else() set(NO_TESTS 1) endif() diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index 32d156b8cd0..9bb06afc9bf 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -40,6 +40,8 @@ class Status; class ArrayVisitor { public: + virtual ~ArrayVisitor() = default; + virtual Status Visit(const NullArray& array) = 0; virtual Status Visit(const BooleanArray& array) = 0; virtual Status Visit(const Int8Array& array) = 0; diff --git a/cpp/src/arrow/io/CMakeLists.txt b/cpp/src/arrow/io/CMakeLists.txt index b8882e46b48..ceb7b737932 100644 --- a/cpp/src/arrow/io/CMakeLists.txt +++ b/cpp/src/arrow/io/CMakeLists.txt @@ -70,13 +70,8 @@ set(ARROW_IO_STATIC_PRIVATE_LINK_LIBS boost_system_static boost_filesystem_static) -if (ARROW_BUILD_STATIC) - set(ARROW_IO_TEST_LINK_LIBS - arrow_io_static) -else() - set(ARROW_IO_TEST_LINK_LIBS - arrow_io_shared) -endif() +set(ARROW_IO_TEST_LINK_LIBS + arrow_io_static) set(ARROW_IO_SRCS file.cc diff --git a/cpp/src/arrow/ipc/CMakeLists.txt b/cpp/src/arrow/ipc/CMakeLists.txt index c047f53d6bf..e7a3fdb1dd8 100644 --- a/cpp/src/arrow/ipc/CMakeLists.txt +++ b/cpp/src/arrow/ipc/CMakeLists.txt @@ -24,20 +24,9 @@ set(ARROW_IPC_SHARED_LINK_LIBS arrow_shared ) -set(ARROW_IPC_STATIC_LINK_LIBS - arrow_static +set(ARROW_IPC_TEST_LINK_LIBS arrow_io_static -) - -if (ARROW_BUILD_STATIC) - set(ARROW_IPC_TEST_LINK_LIBS - arrow_io_static - arrow_ipc_static) -else() - set(ARROW_IPC_TEST_LINK_LIBS - arrow_io_shared - arrow_ipc_shared) -endif() + arrow_ipc_static) set(ARROW_IPC_SRCS adapter.cc diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index a24c007a405..08ac9832982 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -51,12 +51,15 @@ namespace ipc { class RecordBatchWriter : public ArrayVisitor { public: - RecordBatchWriter(MemoryPool* pool, const RecordBatch& batch, - int64_t buffer_start_offset, int max_recursion_depth) + RecordBatchWriter( + MemoryPool* pool, int64_t buffer_start_offset, int max_recursion_depth) : pool_(pool), - batch_(batch), max_recursion_depth_(max_recursion_depth), - buffer_start_offset_(buffer_start_offset) {} + buffer_start_offset_(buffer_start_offset) { + DCHECK_GT(max_recursion_depth, 0); + } + + virtual ~RecordBatchWriter() = default; Status VisitArray(const Array& arr) { if (max_recursion_depth_ <= 0) { @@ -81,7 +84,7 @@ class RecordBatchWriter : public ArrayVisitor { return arr.Accept(this); } - Status Assemble(int64_t* body_length) { + Status Assemble(const RecordBatch& batch, int64_t* body_length) { if (field_nodes_.size() > 0) { field_nodes_.clear(); buffer_meta_.clear(); @@ -89,8 +92,8 @@ class RecordBatchWriter : public ArrayVisitor { } // Perform depth-first traversal of the row-batch - for (int i = 0; i < batch_.num_columns(); ++i) { - RETURN_NOT_OK(VisitArray(*batch_.column(i))); + for (int i = 0; i < batch.num_columns(); ++i) { + RETURN_NOT_OK(VisitArray(*batch.column(i))); } // The position for the start of a buffer relative to the passed frame of @@ -127,16 +130,22 @@ class RecordBatchWriter : public ArrayVisitor { return Status::OK(); } - Status WriteMetadata( - int64_t body_length, io::OutputStream* dst, int32_t* metadata_length) { + // Override this for writing dictionary metadata + virtual Status WriteMetadataMessage( + int32_t num_rows, int64_t body_length, std::shared_ptr* out) { + return WriteRecordBatchMessage( + num_rows, body_length, field_nodes_, buffer_meta_, out); + } + + Status WriteMetadata(int32_t num_rows, int64_t body_length, io::OutputStream* dst, + int32_t* metadata_length) { // Now that we have computed the locations of all of the buffers in shared // memory, the data header can be converted to a flatbuffer and written out // // Note: The memory written here is prefixed by the size of the flatbuffer // itself as an int32_t. std::shared_ptr metadata_fb; - RETURN_NOT_OK(WriteRecordBatchMetadata( - batch_.num_rows(), body_length, field_nodes_, buffer_meta_, &metadata_fb)); + RETURN_NOT_OK(WriteMetadataMessage(num_rows, body_length, &metadata_fb)); // Need to write 4 bytes (metadata size), the metadata, plus padding to // end on an 8-byte offset @@ -166,15 +175,16 @@ class RecordBatchWriter : public ArrayVisitor { return Status::OK(); } - Status Write(io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length) { - RETURN_NOT_OK(Assemble(body_length)); + Status Write(const RecordBatch& batch, io::OutputStream* dst, int32_t* metadata_length, + int64_t* body_length) { + RETURN_NOT_OK(Assemble(batch, body_length)); #ifndef NDEBUG int64_t start_position, current_position; RETURN_NOT_OK(dst->Tell(&start_position)); #endif - RETURN_NOT_OK(WriteMetadata(*body_length, dst, metadata_length)); + RETURN_NOT_OK(WriteMetadata(batch.num_rows(), *body_length, dst, metadata_length)); #ifndef NDEBUG RETURN_NOT_OK(dst->Tell(¤t_position)); @@ -206,17 +216,17 @@ class RecordBatchWriter : public ArrayVisitor { return Status::OK(); } - Status GetTotalSize(int64_t* size) { + Status GetTotalSize(const RecordBatch& batch, int64_t* size) { // emulates the behavior of Write without actually writing int32_t metadata_length = 0; int64_t body_length = 0; MockOutputStream dst; - RETURN_NOT_OK(Write(&dst, &metadata_length, &body_length)); + RETURN_NOT_OK(Write(batch, &dst, &metadata_length, &body_length)); *size = dst.GetExtentBytesWritten(); return Status::OK(); } - private: + protected: Status Visit(const NullArray& array) override { return Status::NotImplemented("null"); } template @@ -468,15 +478,12 @@ class RecordBatchWriter : public ArrayVisitor { } Status Visit(const DictionaryArray& array) override { - // Dictionary written out separately - const auto& indices = static_cast(*array.indices().get()); - buffers_.push_back(indices.data()); - return Status::OK(); + // Dictionary written out separately. Slice offset contained in the indices + return array.indices()->Accept(this); } // In some cases, intermediate buffers may need to be allocated (with sliced arrays) MemoryPool* pool_; - const RecordBatch& batch_; std::vector field_nodes_; std::vector buffer_meta_; @@ -486,17 +493,51 @@ class RecordBatchWriter : public ArrayVisitor { int64_t buffer_start_offset_; }; +class DictionaryWriter : public RecordBatchWriter { + public: + using RecordBatchWriter::RecordBatchWriter; + + Status WriteMetadataMessage( + int32_t num_rows, int64_t body_length, std::shared_ptr* out) override { + return WriteDictionaryMessage( + dictionary_id_, num_rows, body_length, field_nodes_, buffer_meta_, out); + } + + Status Write(int64_t dictionary_id, const std::shared_ptr& dictionary, + io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length) { + dictionary_id_ = dictionary_id; + + // Make a dummy record batch. A bit tedious as we have to make a schema + std::vector> fields = { + arrow::field("dictionary", dictionary->type())}; + auto schema = std::make_shared(fields); + RecordBatch batch(schema, dictionary->length(), {dictionary}); + + return RecordBatchWriter::Write(batch, dst, metadata_length, body_length); + } + + private: + // TODO(wesm): Setting this in Write is a bit unclean, but it works + int64_t dictionary_id_; +}; + Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, MemoryPool* pool, int max_recursion_depth) { - DCHECK_GT(max_recursion_depth, 0); - RecordBatchWriter serializer(pool, batch, buffer_start_offset, max_recursion_depth); - return serializer.Write(dst, metadata_length, body_length); + RecordBatchWriter writer(pool, buffer_start_offset, max_recursion_depth); + return writer.Write(batch, dst, metadata_length, body_length); +} + +Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr& dictionary, + int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, + int64_t* body_length, MemoryPool* pool) { + DictionaryWriter writer(pool, buffer_start_offset, kMaxIpcRecursionDepth); + return writer.Write(dictionary_id, dictionary, dst, metadata_length, body_length); } Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size) { - RecordBatchWriter serializer(default_memory_pool(), batch, 0, kMaxIpcRecursionDepth); - RETURN_NOT_OK(serializer.GetTotalSize(size)); + RecordBatchWriter writer(default_memory_pool(), 0, kMaxIpcRecursionDepth); + RETURN_NOT_OK(writer.GetTotalSize(batch, size)); return Status::OK(); } @@ -580,10 +621,9 @@ class ArrayLoader : public TypeVisitor { Status LoadPrimitive(const DataType& type) { FieldMetadata field_meta; - std::shared_ptr null_bitmap; - RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); + std::shared_ptr null_bitmap, data; - std::shared_ptr data; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); if (field_meta.length > 0) { RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &data)); } else { @@ -597,11 +637,9 @@ class ArrayLoader : public TypeVisitor { template Status LoadBinary() { FieldMetadata field_meta; - std::shared_ptr null_bitmap; - RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); + std::shared_ptr null_bitmap, offsets, values; - std::shared_ptr offsets; - std::shared_ptr values; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); if (field_meta.length > 0) { RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &offsets)); RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &values)); @@ -661,11 +699,9 @@ class ArrayLoader : public TypeVisitor { Status Visit(const ListType& type) override { FieldMetadata field_meta; - std::shared_ptr null_bitmap; + std::shared_ptr null_bitmap, offsets; RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); - - std::shared_ptr offsets; if (field_meta.length > 0) { RETURN_NOT_OK(GetBuffer(context_->buffer_index, &offsets)); } else { @@ -715,12 +751,9 @@ class ArrayLoader : public TypeVisitor { Status Visit(const UnionType& type) override { FieldMetadata field_meta; - std::shared_ptr null_bitmap; - RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); - - std::shared_ptr type_ids = nullptr; - std::shared_ptr offsets = nullptr; + std::shared_ptr null_bitmap, type_ids, offsets; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); if (field_meta.length > 0) { RETURN_NOT_OK(GetBuffer(context_->buffer_index, &type_ids)); if (type.mode == UnionMode::DENSE) { @@ -738,13 +771,23 @@ class ArrayLoader : public TypeVisitor { } Status Visit(const DictionaryType& type) override { - return Status::NotImplemented("dictionary"); + FieldMetadata field_meta; + std::shared_ptr null_bitmap, indices_data; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &indices_data)); + + std::shared_ptr indices; + RETURN_NOT_OK(MakePrimitiveArray(type.index_type(), field_meta.length, indices_data, + null_bitmap, field_meta.null_count, 0, &indices)); + + result_ = std::make_shared(field_.type, indices); + return Status::OK(); }; }; class RecordBatchReader { public: - RecordBatchReader(const std::shared_ptr& metadata, + RecordBatchReader(const RecordBatchMetadata& metadata, const std::shared_ptr& schema, int max_recursion_depth, io::ReadableFileInterface* file) : metadata_(metadata), @@ -758,7 +801,7 @@ class RecordBatchReader { // The field_index and buffer_index are incremented in the ArrayLoader // based on how much of the batch is "consumed" (through nested data // reconstruction, for example) - context_.metadata = metadata_.get(); + context_.metadata = &metadata_; context_.field_index = 0; context_.buffer_index = 0; context_.max_recursion_depth = max_recursion_depth_; @@ -768,50 +811,58 @@ class RecordBatchReader { RETURN_NOT_OK(loader.Load(&arrays[i])); } - *out = std::make_shared(schema_, metadata_->length(), arrays); + *out = std::make_shared(schema_, metadata_.length(), arrays); return Status::OK(); } private: RecordBatchContext context_; - std::shared_ptr metadata_; + const RecordBatchMetadata& metadata_; std::shared_ptr schema_; int max_recursion_depth_; io::ReadableFileInterface* file_; }; -Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length, - io::ReadableFileInterface* file, std::shared_ptr* metadata) { - std::shared_ptr buffer; - RETURN_NOT_OK(file->ReadAt(offset, metadata_length, &buffer)); - - int32_t flatbuffer_size = *reinterpret_cast(buffer->data()); - - if (flatbuffer_size + static_cast(sizeof(int32_t)) > metadata_length) { - std::stringstream ss; - ss << "flatbuffer size " << metadata_length << " invalid. File offset: " << offset - << ", metadata length: " << metadata_length; - return Status::Invalid(ss.str()); - } - - std::shared_ptr message; - RETURN_NOT_OK(Message::Open(buffer, 4, &message)); - *metadata = std::make_shared(message); - return Status::OK(); -} - -Status ReadRecordBatch(const std::shared_ptr& metadata, +Status ReadRecordBatch(const RecordBatchMetadata& metadata, const std::shared_ptr& schema, io::ReadableFileInterface* file, std::shared_ptr* out) { return ReadRecordBatch(metadata, schema, kMaxIpcRecursionDepth, file, out); } -Status ReadRecordBatch(const std::shared_ptr& metadata, +Status ReadRecordBatch(const RecordBatchMetadata& metadata, const std::shared_ptr& schema, int max_recursion_depth, io::ReadableFileInterface* file, std::shared_ptr* out) { RecordBatchReader reader(metadata, schema, max_recursion_depth, file); return reader.Read(out); } +Status ReadDictionary(const DictionaryBatchMetadata& metadata, + const DictionaryTypeMap& dictionary_types, io::ReadableFileInterface* file, + std::shared_ptr* out) { + int64_t id = metadata.id(); + auto it = dictionary_types.find(id); + if (it == dictionary_types.end()) { + std::stringstream ss; + ss << "Do not have type metadata for dictionary with id: " << id; + return Status::KeyError(ss.str()); + } + + std::vector> fields = {it->second}; + + // We need a schema for the record batch + auto dummy_schema = std::make_shared(fields); + + // The dictionary is embedded in a record batch with a single column + std::shared_ptr batch; + RETURN_NOT_OK(ReadRecordBatch(metadata.record_batch(), dummy_schema, file, &batch)); + + if (batch->num_columns() != 1) { + return Status::Invalid("Dictionary record batch must only contain one field"); + } + + *out = batch->column(0); + return Status::OK(); +} + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/adapter.h b/cpp/src/arrow/ipc/adapter.h index 83542d0b066..b7d8fa93d36 100644 --- a/cpp/src/arrow/ipc/adapter.h +++ b/cpp/src/arrow/ipc/adapter.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/ipc/metadata.h" #include "arrow/util/visibility.h" namespace arrow { @@ -44,8 +45,6 @@ class OutputStream; namespace ipc { -class RecordBatchMetadata; - // ---------------------------------------------------------------------- // Write path // We have trouble decoding flatbuffers if the size i > 70, so 64 is a nice round number @@ -72,34 +71,35 @@ constexpr int kMaxIpcRecursionDepth = 64; // // @param(out) body_length: the size of the contiguous buffer block plus // padding bytes -Status ARROW_EXPORT WriteRecordBatch(const RecordBatch& batch, +Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, + io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, + MemoryPool* pool, int max_recursion_depth = kMaxIpcRecursionDepth); + +// Write Array as a DictionaryBatch message +Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr& dictionary, int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, - int64_t* body_length, MemoryPool* pool, - int max_recursion_depth = kMaxIpcRecursionDepth); + int64_t* body_length, MemoryPool* pool); // Compute the precise number of bytes needed in a contiguous memory segment to // write the record batch. This involves generating the complete serialized // Flatbuffers metadata. -Status ARROW_EXPORT GetRecordBatchSize(const RecordBatch& batch, int64_t* size); +Status GetRecordBatchSize(const RecordBatch& batch, int64_t* size); // ---------------------------------------------------------------------- // "Read" path; does not copy data if the input supports zero copy reads -// Read the record batch flatbuffer metadata starting at the indicated file offset -// -// The flatbuffer is expected to be length-prefixed, so the metadata_length -// includes at least the length prefix and the flatbuffer -Status ARROW_EXPORT ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length, - io::ReadableFileInterface* file, std::shared_ptr* metadata); - -Status ARROW_EXPORT ReadRecordBatch(const std::shared_ptr& metadata, +Status ReadRecordBatch(const RecordBatchMetadata& metadata, const std::shared_ptr& schema, io::ReadableFileInterface* file, std::shared_ptr* out); -Status ARROW_EXPORT ReadRecordBatch(const std::shared_ptr& metadata, +Status ReadRecordBatch(const RecordBatchMetadata& metadata, const std::shared_ptr& schema, int max_recursion_depth, io::ReadableFileInterface* file, std::shared_ptr* out); +Status ReadDictionary(const DictionaryBatchMetadata& metadata, + const DictionaryTypeMap& dictionary_types, io::ReadableFileInterface* file, + std::shared_ptr* out); + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/file.cc b/cpp/src/arrow/ipc/file.cc index 3b183261102..c1d483f1fbb 100644 --- a/cpp/src/arrow/ipc/file.cc +++ b/cpp/src/arrow/ipc/file.cc @@ -36,8 +36,6 @@ namespace arrow { namespace ipc { static constexpr const char* kArrowMagicBytes = "ARROW1"; -// ---------------------------------------------------------------------- -// File footer static flatbuffers::Offset> FileBlocksToFlatbuffer(FBB& fbb, const std::vector& blocks) { @@ -51,11 +49,12 @@ FileBlocksToFlatbuffer(FBB& fbb, const std::vector& blocks) { } Status WriteFileFooter(const Schema& schema, const std::vector& dictionaries, - const std::vector& record_batches, io::OutputStream* out) { + const std::vector& record_batches, DictionaryMemo* dictionary_memo, + io::OutputStream* out) { FBB fbb; flatbuffers::Offset fb_schema; - RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, &fb_schema)); + RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); auto fb_dictionaries = FileBlocksToFlatbuffer(fbb, dictionaries); auto fb_record_batches = FileBlocksToFlatbuffer(fbb, record_batches); @@ -74,87 +73,6 @@ static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) { return FileBlock(block->offset(), block->metaDataLength(), block->bodyLength()); } -class FileFooter::FileFooterImpl { - public: - FileFooterImpl(const std::shared_ptr& buffer, const flatbuf::Footer* footer) - : buffer_(buffer), footer_(footer) {} - - int num_dictionaries() const { return footer_->dictionaries()->size(); } - - int num_record_batches() const { return footer_->recordBatches()->size(); } - - MetadataVersion::type version() const { - switch (footer_->version()) { - case flatbuf::MetadataVersion_V1: - return MetadataVersion::V1; - case flatbuf::MetadataVersion_V2: - return MetadataVersion::V2; - // Add cases as other versions become available - default: - return MetadataVersion::V2; - } - } - - FileBlock record_batch(int i) const { - return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i)); - } - - FileBlock dictionary(int i) const { - return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i)); - } - - Status GetSchema(std::shared_ptr* out) const { - auto schema_msg = std::make_shared(nullptr, footer_->schema()); - return schema_msg->GetSchema(out); - } - - private: - // Retain reference to memory - std::shared_ptr buffer_; - - const flatbuf::Footer* footer_; -}; - -FileFooter::FileFooter() {} - -FileFooter::~FileFooter() {} - -Status FileFooter::Open( - const std::shared_ptr& buffer, std::unique_ptr* out) { - const flatbuf::Footer* footer = flatbuf::GetFooter(buffer->data()); - - *out = std::unique_ptr(new FileFooter()); - - // TODO(wesm): Verify the footer - (*out)->impl_.reset(new FileFooterImpl(buffer, footer)); - - return Status::OK(); -} - -int FileFooter::num_dictionaries() const { - return impl_->num_dictionaries(); -} - -int FileFooter::num_record_batches() const { - return impl_->num_record_batches(); -} - -MetadataVersion::type FileFooter::version() const { - return impl_->version(); -} - -FileBlock FileFooter::record_batch(int i) const { - return impl_->record_batch(i); -} - -FileBlock FileFooter::dictionary(int i) const { - return impl_->dictionary(i); -} - -Status FileFooter::GetSchema(std::shared_ptr* out) const { - return impl_->GetSchema(out); -} - // ---------------------------------------------------------------------- // File writer implementation @@ -171,22 +89,17 @@ Status FileWriter::Open(io::OutputStream* sink, const std::shared_ptr& s Status FileWriter::Start() { RETURN_NOT_OK(WriteAligned( reinterpret_cast(kArrowMagicBytes), strlen(kArrowMagicBytes))); - started_ = true; - return Status::OK(); -} -Status FileWriter::WriteRecordBatch(const RecordBatch& batch) { - // Push an empty FileBlock - // Append metadata, to be written in the footer later - record_batches_.emplace_back(0, 0, 0); - return StreamWriter::WriteRecordBatch( - batch, &record_batches_[record_batches_.size() - 1]); + // We write the schema at the start of the file (and the end). This also + // writes all the dictionaries at the beginning of the file + return StreamWriter::Start(); } Status FileWriter::Close() { // Write metadata int64_t initial_position = position_; - RETURN_NOT_OK(WriteFileFooter(*schema_, dictionaries_, record_batches_, sink_)); + RETURN_NOT_OK(WriteFileFooter( + *schema_, dictionaries_, record_batches_, dictionary_memo_.get(), sink_)); RETURN_NOT_OK(UpdatePosition()); // Write footer length @@ -204,89 +117,180 @@ Status FileWriter::Close() { // ---------------------------------------------------------------------- // Reader implementation -FileReader::FileReader( - const std::shared_ptr& file, int64_t footer_offset) - : file_(file), footer_offset_(footer_offset) {} +class FileReader::FileReaderImpl { + public: + FileReaderImpl() { dictionary_memo_ = std::make_shared(); } -FileReader::~FileReader() {} + Status ReadFooter() { + int magic_size = static_cast(strlen(kArrowMagicBytes)); -Status FileReader::Open(const std::shared_ptr& file, - std::shared_ptr* reader) { - int64_t footer_offset; - RETURN_NOT_OK(file->GetSize(&footer_offset)); - return Open(file, footer_offset, reader); -} + if (footer_offset_ <= magic_size * 2 + 4) { + std::stringstream ss; + ss << "File is too small: " << footer_offset_; + return Status::Invalid(ss.str()); + } -Status FileReader::Open(const std::shared_ptr& file, - int64_t footer_offset, std::shared_ptr* reader) { - *reader = std::shared_ptr(new FileReader(file, footer_offset)); - return (*reader)->ReadFooter(); -} + std::shared_ptr buffer; + int file_end_size = magic_size + sizeof(int32_t); + RETURN_NOT_OK(file_->ReadAt(footer_offset_ - file_end_size, file_end_size, &buffer)); + + if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) { + return Status::Invalid("Not an Arrow file"); + } + + int32_t footer_length = *reinterpret_cast(buffer->data()); + + if (footer_length <= 0 || footer_length + magic_size * 2 + 4 > footer_offset_) { + return Status::Invalid("File is smaller than indicated metadata size"); + } -Status FileReader::ReadFooter() { - int magic_size = static_cast(strlen(kArrowMagicBytes)); + // Now read the footer + RETURN_NOT_OK(file_->ReadAt( + footer_offset_ - footer_length - file_end_size, footer_length, &footer_buffer_)); - if (footer_offset_ <= magic_size * 2 + 4) { - std::stringstream ss; - ss << "File is too small: " << footer_offset_; - return Status::Invalid(ss.str()); + // TODO(wesm): Verify the footer + footer_ = flatbuf::GetFooter(footer_buffer_->data()); + schema_metadata_.reset(new SchemaMetadata(nullptr, footer_->schema())); + + return Status::OK(); + } + + int num_dictionaries() const { return footer_->dictionaries()->size(); } + + int num_record_batches() const { return footer_->recordBatches()->size(); } + + MetadataVersion::type version() const { + switch (footer_->version()) { + case flatbuf::MetadataVersion_V1: + return MetadataVersion::V1; + case flatbuf::MetadataVersion_V2: + return MetadataVersion::V2; + // Add cases as other versions become available + default: + return MetadataVersion::V2; + } + } + + FileBlock record_batch(int i) const { + return FileBlockFromFlatbuffer(footer_->recordBatches()->Get(i)); + } + + FileBlock dictionary(int i) const { + return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i)); } - std::shared_ptr buffer; - int file_end_size = magic_size + sizeof(int32_t); - RETURN_NOT_OK(file_->ReadAt(footer_offset_ - file_end_size, file_end_size, &buffer)); + const SchemaMetadata& schema_metadata() const { return *schema_metadata_; } + + Status GetRecordBatch(int i, std::shared_ptr* batch) { + DCHECK_GE(i, 0); + DCHECK_LT(i, num_record_batches()); + FileBlock block = record_batch(i); + + std::shared_ptr message; + RETURN_NOT_OK( + ReadMessage(block.offset, block.metadata_length, file_.get(), &message)); + auto metadata = std::make_shared(message); - if (memcmp(buffer->data() + sizeof(int32_t), kArrowMagicBytes, magic_size)) { - return Status::Invalid("Not an Arrow file"); + // TODO(wesm): ARROW-388 -- the buffer frame of reference is 0 (see + // ARROW-384). + std::shared_ptr buffer_block; + RETURN_NOT_OK(file_->Read(block.body_length, &buffer_block)); + io::BufferReader reader(buffer_block); + + return ReadRecordBatch(*metadata, schema_, &reader, batch); } - int32_t footer_length = *reinterpret_cast(buffer->data()); + Status ReadSchema() { + RETURN_NOT_OK(schema_metadata_->GetDictionaryTypes(&dictionary_fields_)); + + // Read all the dictionaries + for (int i = 0; i < num_dictionaries(); ++i) { + FileBlock block = dictionary(i); + std::shared_ptr message; + RETURN_NOT_OK( + ReadMessage(block.offset, block.metadata_length, file_.get(), &message)); + + // TODO(wesm): ARROW-577: This code is duplicated, can be fixed with a more + // invasive refactor + DictionaryBatchMetadata metadata(message); + + // TODO(wesm): ARROW-388 -- the buffer frame of reference is 0 (see + // ARROW-384). + std::shared_ptr buffer_block; + RETURN_NOT_OK(file_->Read(block.body_length, &buffer_block)); + io::BufferReader reader(buffer_block); + + std::shared_ptr dictionary; + RETURN_NOT_OK(ReadDictionary(metadata, dictionary_fields_, &reader, &dictionary)); + RETURN_NOT_OK(dictionary_memo_->AddDictionary(metadata.id(), dictionary)); + } - if (footer_length <= 0 || footer_length + magic_size * 2 + 4 > footer_offset_) { - return Status::Invalid("File is smaller than indicated metadata size"); + // Get the schema + return schema_metadata_->GetSchema(*dictionary_memo_, &schema_); } - // Now read the footer - RETURN_NOT_OK(file_->ReadAt( - footer_offset_ - footer_length - file_end_size, footer_length, &buffer)); - RETURN_NOT_OK(FileFooter::Open(buffer, &footer_)); + Status Open( + const std::shared_ptr& file, int64_t footer_offset) { + file_ = file; + footer_offset_ = footer_offset; + RETURN_NOT_OK(ReadFooter()); + return ReadSchema(); + } + + std::shared_ptr schema() const { return schema_; } + + private: + std::shared_ptr file_; - // Get the schema - return footer_->GetSchema(&schema_); + // The location where the Arrow file layout ends. May be the end of the file + // or some other location if embedded in a larger file. + int64_t footer_offset_; + + // Footer metadata + std::shared_ptr footer_buffer_; + const flatbuf::Footer* footer_; + std::unique_ptr schema_metadata_; + + DictionaryTypeMap dictionary_fields_; + std::shared_ptr dictionary_memo_; + + // Reconstructed schema, including any read dictionaries + std::shared_ptr schema_; +}; + +FileReader::FileReader() { + impl_.reset(new FileReaderImpl()); } -std::shared_ptr FileReader::schema() const { - return schema_; +FileReader::~FileReader() {} + +Status FileReader::Open(const std::shared_ptr& file, + std::shared_ptr* reader) { + int64_t footer_offset; + RETURN_NOT_OK(file->GetSize(&footer_offset)); + return Open(file, footer_offset, reader); +} + +Status FileReader::Open(const std::shared_ptr& file, + int64_t footer_offset, std::shared_ptr* reader) { + *reader = std::shared_ptr(new FileReader()); + return (*reader)->impl_->Open(file, footer_offset); } -int FileReader::num_dictionaries() const { - return footer_->num_dictionaries(); +std::shared_ptr FileReader::schema() const { + return impl_->schema(); } int FileReader::num_record_batches() const { - return footer_->num_record_batches(); + return impl_->num_record_batches(); } MetadataVersion::type FileReader::version() const { - return footer_->version(); + return impl_->version(); } Status FileReader::GetRecordBatch(int i, std::shared_ptr* batch) { - DCHECK_GE(i, 0); - DCHECK_LT(i, num_record_batches()); - FileBlock block = footer_->record_batch(i); - - std::shared_ptr metadata; - RETURN_NOT_OK(ReadRecordBatchMetadata( - block.offset, block.metadata_length, file_.get(), &metadata)); - - // TODO(wesm): ARROW-388 -- the buffer frame of reference is 0 (see - // ARROW-384). - std::shared_ptr buffer_block; - RETURN_NOT_OK(file_->Read(block.body_length, &buffer_block)); - io::BufferReader reader(buffer_block); - - return ReadRecordBatch(metadata, schema_, &reader, batch); + return impl_->GetRecordBatch(i, batch); } } // namespace ipc diff --git a/cpp/src/arrow/ipc/file.h b/cpp/src/arrow/ipc/file.h index cf0baab820e..524766ccb33 100644 --- a/cpp/src/arrow/ipc/file.h +++ b/cpp/src/arrow/ipc/file.h @@ -45,45 +45,21 @@ class ReadableFileInterface; namespace ipc { Status WriteFileFooter(const Schema& schema, const std::vector& dictionaries, - const std::vector& record_batches, io::OutputStream* out); - -class ARROW_EXPORT FileFooter { - public: - ~FileFooter(); - - static Status Open( - const std::shared_ptr& buffer, std::unique_ptr* out); - - int num_dictionaries() const; - int num_record_batches() const; - MetadataVersion::type version() const; - - FileBlock record_batch(int i) const; - FileBlock dictionary(int i) const; - - Status GetSchema(std::shared_ptr* out) const; - - private: - FileFooter(); - class FileFooterImpl; - std::unique_ptr impl_; -}; + const std::vector& record_batches, DictionaryMemo* dictionary_memo, + io::OutputStream* out); class ARROW_EXPORT FileWriter : public StreamWriter { public: static Status Open(io::OutputStream* sink, const std::shared_ptr& schema, std::shared_ptr* out); - Status WriteRecordBatch(const RecordBatch& batch) override; + using StreamWriter::WriteRecordBatch; Status Close() override; private: FileWriter(io::OutputStream* sink, const std::shared_ptr& schema); Status Start() override; - - std::vector dictionaries_; - std::vector record_batches_; }; class ARROW_EXPORT FileReader { @@ -108,13 +84,9 @@ class ARROW_EXPORT FileReader { static Status Open(const std::shared_ptr& file, int64_t footer_offset, std::shared_ptr* reader); + /// The schema includes any dictionaries std::shared_ptr schema() const; - // Shared dictionaries for dictionary-encoding cross record batches - // TODO(wesm): Implement dictionary reading when we also have dictionary - // encoding - int num_dictionaries() const; - int num_record_batches() const; MetadataVersion::type version() const; @@ -127,19 +99,10 @@ class ARROW_EXPORT FileReader { Status GetRecordBatch(int i, std::shared_ptr* batch); private: - FileReader( - const std::shared_ptr& file, int64_t footer_offset); - - Status ReadFooter(); - - std::shared_ptr file_; - - // The location where the Arrow file layout ends. May be the end of the file - // or some other location if embedded in a larger file. - int64_t footer_offset_; + FileReader(); - std::unique_ptr footer_; - std::shared_ptr schema_; + class ARROW_NO_EXPORT FileReaderImpl; + std::unique_ptr impl_; }; } // namespace ipc diff --git a/cpp/src/arrow/ipc/ipc-adapter-test.cc b/cpp/src/arrow/ipc/ipc-adapter-test.cc index d11b95b167d..89993638932 100644 --- a/cpp/src/arrow/ipc/ipc-adapter-test.cc +++ b/cpp/src/arrow/ipc/ipc-adapter-test.cc @@ -27,6 +27,7 @@ #include "arrow/io/memory.h" #include "arrow/io/test-common.h" #include "arrow/ipc/adapter.h" +#include "arrow/ipc/metadata.h" #include "arrow/ipc/test-common.h" #include "arrow/ipc/util.h" @@ -40,12 +41,8 @@ namespace arrow { namespace ipc { -class TestWriteRecordBatch : public ::testing::TestWithParam, - public io::MemoryMapFixture { +class IpcTestFixture : public io::MemoryMapFixture { public: - void SetUp() { pool_ = default_memory_pool(); } - void TearDown() { io::MemoryMapFixture::TearDown(); } - Status RoundTripHelper(const RecordBatch& batch, int memory_map_size, std::shared_ptr* batch_result) { std::string path = "test-write-row-batch"; @@ -59,8 +56,9 @@ class TestWriteRecordBatch : public ::testing::TestWithParam, RETURN_NOT_OK(WriteRecordBatch( batch, buffer_offset, mmap_.get(), &metadata_length, &body_length, pool_)); - std::shared_ptr metadata; - RETURN_NOT_OK(ReadRecordBatchMetadata(0, metadata_length, mmap_.get(), &metadata)); + std::shared_ptr message; + RETURN_NOT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); + auto metadata = std::make_shared(message); // The buffer offsets start at 0, so we must construct a // ReadableFileInterface according to that frame of reference @@ -68,7 +66,7 @@ class TestWriteRecordBatch : public ::testing::TestWithParam, RETURN_NOT_OK(mmap_->ReadAt(metadata_length, body_length, &buffer_payload)); io::BufferReader buffer_reader(buffer_payload); - return ReadRecordBatch(metadata, batch.schema(), &buffer_reader, batch_result); + return ReadRecordBatch(*metadata, batch.schema(), &buffer_reader, batch_result); } void CheckRoundtrip(const RecordBatch& batch, int64_t buffer_size) { @@ -112,14 +110,29 @@ class TestWriteRecordBatch : public ::testing::TestWithParam, MemoryPool* pool_; }; -TEST_P(TestWriteRecordBatch, RoundTrip) { +class TestWriteRecordBatch : public ::testing::Test, public IpcTestFixture { + public: + void SetUp() { pool_ = default_memory_pool(); } + void TearDown() { io::MemoryMapFixture::TearDown(); } +}; + +class TestRecordBatchParam : public ::testing::TestWithParam, + public IpcTestFixture { + public: + void SetUp() { pool_ = default_memory_pool(); } + void TearDown() { io::MemoryMapFixture::TearDown(); } + using IpcTestFixture::RoundTripHelper; + using IpcTestFixture::CheckRoundtrip; +}; + +TEST_P(TestRecordBatchParam, RoundTrip) { std::shared_ptr batch; ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue CheckRoundtrip(*batch, 1 << 20); } -TEST_P(TestWriteRecordBatch, SliceRoundTrip) { +TEST_P(TestRecordBatchParam, SliceRoundTrip) { std::shared_ptr batch; ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue @@ -130,7 +143,7 @@ TEST_P(TestWriteRecordBatch, SliceRoundTrip) { CheckRoundtrip(*sliced_batch, 1 << 20); } -TEST_P(TestWriteRecordBatch, ZeroLengthArrays) { +TEST_P(TestRecordBatchParam, ZeroLengthArrays) { std::shared_ptr batch; ASSERT_OK((*GetParam())(&batch)); // NOLINT clang-tidy gtest issue @@ -159,10 +172,10 @@ TEST_P(TestWriteRecordBatch, ZeroLengthArrays) { } INSTANTIATE_TEST_CASE_P( - RoundTripTests, TestWriteRecordBatch, + RoundTripTests, TestRecordBatchParam, ::testing::Values(&MakeIntRecordBatch, &MakeStringTypesRecordBatch, &MakeNonNullRecordBatch, &MakeZeroLengthRecordBatch, &MakeListRecordBatch, - &MakeDeeplyNestedList, &MakeStruct, &MakeUnion)); + &MakeDeeplyNestedList, &MakeStruct, &MakeUnion, &MakeDictionary)); void TestGetRecordBatchSize(std::shared_ptr batch) { ipc::MockOutputStream mock; @@ -251,8 +264,9 @@ TEST_F(RecursionLimits, ReadLimit) { std::shared_ptr schema; ASSERT_OK(WriteToMmap(64, true, &metadata_length, &body_length, &schema)); - std::shared_ptr metadata; - ASSERT_OK(ReadRecordBatchMetadata(0, metadata_length, mmap_.get(), &metadata)); + std::shared_ptr message; + ASSERT_OK(ReadMessage(0, metadata_length, mmap_.get(), &message)); + auto metadata = std::make_shared(message); std::shared_ptr payload; ASSERT_OK(mmap_->ReadAt(metadata_length, body_length, &payload)); @@ -260,7 +274,7 @@ TEST_F(RecursionLimits, ReadLimit) { io::BufferReader reader(payload); std::shared_ptr batch; - ASSERT_RAISES(Invalid, ReadRecordBatch(metadata, schema, &reader, &batch)); + ASSERT_RAISES(Invalid, ReadRecordBatch(*metadata, schema, &reader, &batch)); } } // namespace ipc diff --git a/cpp/src/arrow/ipc/ipc-file-test.cc b/cpp/src/arrow/ipc/ipc-file-test.cc index 7cd8054679e..4b82aab0e39 100644 --- a/cpp/src/arrow/ipc/ipc-file-test.cc +++ b/cpp/src/arrow/ipc/ipc-file-test.cc @@ -180,72 +180,44 @@ TEST_P(TestStreamFormat, RoundTrip) { #define BATCH_CASES() \ ::testing::Values(&MakeIntRecordBatch, &MakeListRecordBatch, &MakeNonNullRecordBatch, \ &MakeZeroLengthRecordBatch, &MakeDeeplyNestedList, &MakeStringTypesRecordBatch, \ - &MakeStruct); + &MakeStruct, &MakeDictionary); INSTANTIATE_TEST_CASE_P(FileRoundTripTests, TestFileFormat, BATCH_CASES()); INSTANTIATE_TEST_CASE_P(StreamRoundTripTests, TestStreamFormat, BATCH_CASES()); -class TestFileFooter : public ::testing::Test { - public: - void SetUp() {} - - void CheckRoundtrip(const Schema& schema, const std::vector& dictionaries, - const std::vector& record_batches) { - auto buffer = std::make_shared(); - io::BufferOutputStream stream(buffer); - - ASSERT_OK(WriteFileFooter(schema, dictionaries, record_batches, &stream)); - - std::unique_ptr footer; - ASSERT_OK(FileFooter::Open(buffer, &footer)); - - ASSERT_EQ(MetadataVersion::V2, footer->version()); +void CheckBatchDictionaries(const RecordBatch& batch) { + // Check that dictionaries that should be the same are the same + auto schema = batch.schema(); - // Check schema - std::shared_ptr schema2; - ASSERT_OK(footer->GetSchema(&schema2)); - AssertSchemaEqual(schema, *schema2); + const auto& t0 = static_cast(*schema->field(0)->type); + const auto& t1 = static_cast(*schema->field(1)->type); - // Check blocks - ASSERT_EQ(dictionaries.size(), footer->num_dictionaries()); - ASSERT_EQ(record_batches.size(), footer->num_record_batches()); + ASSERT_EQ(t0.dictionary().get(), t1.dictionary().get()); - for (int i = 0; i < footer->num_dictionaries(); ++i) { - CheckBlocks(dictionaries[i], footer->dictionary(i)); - } - - for (int i = 0; i < footer->num_record_batches(); ++i) { - CheckBlocks(record_batches[i], footer->record_batch(i)); - } - } + // Same dictionary used for list values + const auto& t3 = static_cast(*schema->field(3)->type); + const auto& t3_value = static_cast(*t3.value_type()); + ASSERT_EQ(t0.dictionary().get(), t3_value.dictionary().get()); +} - void CheckBlocks(const FileBlock& left, const FileBlock& right) { - ASSERT_EQ(left.offset, right.offset); - ASSERT_EQ(left.metadata_length, right.metadata_length); - ASSERT_EQ(left.body_length, right.body_length); - } +TEST_F(TestStreamFormat, DictionaryRoundTrip) { + std::shared_ptr batch; + ASSERT_OK(MakeDictionary(&batch)); - private: - std::shared_ptr example_schema_; -}; + std::vector> out_batches; + ASSERT_OK(RoundTripHelper(*batch, &out_batches)); -TEST_F(TestFileFooter, Basics) { - auto f0 = std::make_shared("f0", std::make_shared()); - auto f1 = std::make_shared("f1", std::make_shared()); - Schema schema({f0, f1}); + CheckBatchDictionaries(*out_batches[0]); +} - std::vector dictionaries; - dictionaries.emplace_back(8, 92, 900); - dictionaries.emplace_back(1000, 100, 1900); - dictionaries.emplace_back(3000, 100, 2900); +TEST_F(TestFileFormat, DictionaryRoundTrip) { + std::shared_ptr batch; + ASSERT_OK(MakeDictionary(&batch)); - std::vector record_batches; - record_batches.emplace_back(6000, 100, 900); - record_batches.emplace_back(7000, 100, 1900); - record_batches.emplace_back(9000, 100, 2900); - record_batches.emplace_back(12000, 100, 3900); + std::vector> out_batches; + ASSERT_OK(RoundTripHelper({batch}, &out_batches)); - CheckRoundtrip(schema, dictionaries, record_batches); + CheckBatchDictionaries(*out_batches[0]); } } // namespace ipc diff --git a/cpp/src/arrow/ipc/ipc-metadata-test.cc b/cpp/src/arrow/ipc/ipc-metadata-test.cc index 098f996d292..4fb3204a5b6 100644 --- a/cpp/src/arrow/ipc/ipc-metadata-test.cc +++ b/cpp/src/arrow/ipc/ipc-metadata-test.cc @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "arrow/io/memory.h" +#include "arrow/ipc/metadata-internal.h" #include "arrow/ipc/metadata.h" #include "arrow/ipc/test-common.h" #include "arrow/schema.h" @@ -39,9 +40,9 @@ class TestSchemaMetadata : public ::testing::Test { public: void SetUp() {} - void CheckRoundtrip(const Schema& schema) { + void CheckRoundtrip(const Schema& schema, DictionaryMemo* memo) { std::shared_ptr buffer; - ASSERT_OK(WriteSchema(schema, &buffer)); + ASSERT_OK(WriteSchemaMessage(schema, memo, &buffer)); std::shared_ptr message; ASSERT_OK(Message::Open(buffer, 0, &message)); @@ -51,8 +52,10 @@ class TestSchemaMetadata : public ::testing::Test { auto schema_msg = std::make_shared(message); ASSERT_EQ(schema.num_fields(), schema_msg->num_fields()); + DictionaryMemo empty_memo; + std::shared_ptr schema2; - ASSERT_OK(schema_msg->GetSchema(&schema2)); + ASSERT_OK(schema_msg->GetSchema(empty_memo, &schema2)); AssertSchemaEqual(schema, *schema2); } @@ -74,7 +77,9 @@ TEST_F(TestSchemaMetadata, PrimitiveFields) { auto f10 = std::make_shared("f10", std::make_shared()); Schema schema({f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10}); - CheckRoundtrip(schema); + DictionaryMemo memo; + + CheckRoundtrip(schema, &memo); } TEST_F(TestSchemaMetadata, NestedFields) { @@ -86,7 +91,9 @@ TEST_F(TestSchemaMetadata, NestedFields) { auto f1 = std::make_shared("f1", type2); Schema schema({f0, f1}); - CheckRoundtrip(schema); + DictionaryMemo memo; + + CheckRoundtrip(schema, &memo); } } // namespace ipc diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index cd7722056a3..7c8ddb93c09 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -25,6 +25,7 @@ #include "flatbuffers/flatbuffers.h" +#include "arrow/array.h" #include "arrow/buffer.h" #include "arrow/ipc/Message_generated.h" #include "arrow/schema.h" @@ -115,8 +116,8 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data, } // Forward declaration -static Status FieldToFlatbuffer( - FBB& fbb, const std::shared_ptr& field, FieldOffset* offset); +static Status FieldToFlatbuffer(FBB& fbb, const std::shared_ptr& field, + DictionaryMemo* dictionary_memo, FieldOffset* offset); static Offset IntToFlatbuffer(FBB& fbb, int bitWidth, bool is_signed) { return flatbuf::CreateInt(fbb, bitWidth, is_signed).Union(); @@ -126,34 +127,73 @@ static Offset FloatToFlatbuffer(FBB& fbb, flatbuf::Precision precision) { return flatbuf::CreateFloatingPoint(fbb, precision).Union(); } -static Status ListToFlatbuffer(FBB& fbb, const std::shared_ptr& type, - std::vector* out_children, Offset* offset) { +static Status AppendChildFields(FBB& fbb, const std::shared_ptr& type, + std::vector* out_children, DictionaryMemo* dictionary_memo) { FieldOffset field; - RETURN_NOT_OK(FieldToFlatbuffer(fbb, type->child(0), &field)); - out_children->push_back(field); + for (int i = 0; i < type->num_children(); ++i) { + RETURN_NOT_OK(FieldToFlatbuffer(fbb, type->child(i), dictionary_memo, &field)); + out_children->push_back(field); + } + return Status::OK(); +} + +static Status ListToFlatbuffer(FBB& fbb, const std::shared_ptr& type, + std::vector* out_children, DictionaryMemo* dictionary_memo, + Offset* offset) { + RETURN_NOT_OK(AppendChildFields(fbb, type, out_children, dictionary_memo)); *offset = flatbuf::CreateList(fbb).Union(); return Status::OK(); } static Status StructToFlatbuffer(FBB& fbb, const std::shared_ptr& type, - std::vector* out_children, Offset* offset) { - FieldOffset field; - for (int i = 0; i < type->num_children(); ++i) { - RETURN_NOT_OK(FieldToFlatbuffer(fbb, type->child(i), &field)); - out_children->push_back(field); - } + std::vector* out_children, DictionaryMemo* dictionary_memo, + Offset* offset) { + RETURN_NOT_OK(AppendChildFields(fbb, type, out_children, dictionary_memo)); *offset = flatbuf::CreateStruct_(fbb).Union(); return Status::OK(); } +static Status UnionToFlatBuffer(FBB& fbb, const std::shared_ptr& type, + std::vector* out_children, DictionaryMemo* dictionary_memo, + Offset* offset) { + RETURN_NOT_OK(AppendChildFields(fbb, type, out_children, dictionary_memo)); + + const auto& union_type = static_cast(*type); + + flatbuf::UnionMode mode = union_type.mode == UnionMode::SPARSE + ? flatbuf::UnionMode_Sparse + : flatbuf::UnionMode_Dense; + + std::vector type_ids; + type_ids.reserve(union_type.type_codes.size()); + for (uint8_t code : union_type.type_codes) { + type_ids.push_back(code); + } + + auto fb_type_ids = fbb.CreateVector(type_ids); + + *offset = flatbuf::CreateUnion(fbb, mode, fb_type_ids).Union(); + return Status::OK(); +} + #define INT_TO_FB_CASE(BIT_WIDTH, IS_SIGNED) \ *out_type = flatbuf::Type_Int; \ *offset = IntToFlatbuffer(fbb, BIT_WIDTH, IS_SIGNED); \ break; +// TODO(wesm): Convert this to visitor pattern static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr& type, std::vector* children, std::vector* layout, - flatbuf::Type* out_type, Offset* offset) { + flatbuf::Type* out_type, DictionaryMemo* dictionary_memo, Offset* offset) { + if (type->type == Type::DICTIONARY) { + // In this library, the dictionary "type" is a logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + const auto& dict_type = static_cast(*type); + return TypeToFlatbuffer(fbb, dict_type.dictionary()->type(), children, layout, + out_type, dictionary_memo, offset); + } + std::vector buffer_layout = type->GetBufferLayout(); for (const BufferDescr& descr : buffer_layout) { flatbuf::VectorType vector_type; @@ -217,10 +257,13 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr& type, break; case Type::LIST: *out_type = flatbuf::Type_List; - return ListToFlatbuffer(fbb, type, children, offset); + return ListToFlatbuffer(fbb, type, children, dictionary_memo, offset); case Type::STRUCT: *out_type = flatbuf::Type_Struct_; - return StructToFlatbuffer(fbb, type, children, offset); + return StructToFlatbuffer(fbb, type, children, dictionary_memo, offset); + case Type::UNION: + *out_type = flatbuf::Type_Union; + return UnionToFlatBuffer(fbb, type, children, dictionary_memo, offset); default: *out_type = flatbuf::Type_NONE; // Make clang-tidy happy std::stringstream ss; @@ -230,35 +273,63 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr& type, return Status::OK(); } -static Status FieldToFlatbuffer( - FBB& fbb, const std::shared_ptr& field, FieldOffset* offset) { +using DictionaryOffset = flatbuffers::Offset; + +static DictionaryOffset GetDictionaryEncoding( + FBB& fbb, const DictionaryType& type, DictionaryMemo* memo) { + int64_t dictionary_id = memo->GetId(type.dictionary()); + + // We assume that the dictionary index type (as an integer) has already been + // validated elsewhere, and can safely assume we are dealing with signed + // integers + const auto& fw_index_type = static_cast(*type.index_type()); + + auto index_type_offset = flatbuf::CreateInt(fbb, fw_index_type.bit_width(), true); + + // TODO(wesm): ordered dictionaries + return flatbuf::CreateDictionaryEncoding(fbb, dictionary_id, index_type_offset); +} + +static Status FieldToFlatbuffer(FBB& fbb, const std::shared_ptr& field, + DictionaryMemo* dictionary_memo, FieldOffset* offset) { auto fb_name = fbb.CreateString(field->name); flatbuf::Type type_enum; - Offset type_data; + Offset type_offset; Offset type_layout; std::vector children; std::vector layout; - RETURN_NOT_OK( - TypeToFlatbuffer(fbb, field->type, &children, &layout, &type_enum, &type_data)); + RETURN_NOT_OK(TypeToFlatbuffer( + fbb, field->type, &children, &layout, &type_enum, dictionary_memo, &type_offset)); auto fb_children = fbb.CreateVector(children); auto fb_layout = fbb.CreateVector(layout); + DictionaryOffset dictionary = 0; + if (field->type->type == Type::DICTIONARY) { + dictionary = GetDictionaryEncoding( + fbb, static_cast(*field->type), dictionary_memo); + } + // TODO: produce the list of VectorTypes - *offset = flatbuf::CreateField(fbb, fb_name, field->nullable, type_enum, type_data, - field->dictionary, fb_children, fb_layout); + *offset = flatbuf::CreateField(fbb, fb_name, field->nullable, type_enum, type_offset, + dictionary, fb_children, fb_layout); return Status::OK(); } -Status FieldFromFlatbuffer(const flatbuf::Field* field, std::shared_ptr* out) { - std::shared_ptr type; +Status FieldFromFlatbufferDictionary( + const flatbuf::Field* field, std::shared_ptr* out) { + // Need an empty memo to pass down for constructing children + DictionaryMemo dummy_memo; + + // Any DictionaryEncoding set is ignored here + std::shared_ptr type; auto children = field->children(); std::vector> child_fields(children->size()); for (size_t i = 0; i < children->size(); ++i) { - RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), &child_fields[i])); + RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), dummy_memo, &child_fields[i])); } RETURN_NOT_OK( @@ -268,6 +339,39 @@ Status FieldFromFlatbuffer(const flatbuf::Field* field, std::shared_ptr* return Status::OK(); } +Status FieldFromFlatbuffer(const flatbuf::Field* field, + const DictionaryMemo& dictionary_memo, std::shared_ptr* out) { + std::shared_ptr type; + + const flatbuf::DictionaryEncoding* encoding = field->dictionary(); + + if (encoding == nullptr) { + // The field is not dictionary encoded. We must potentially visit its + // children to fully reconstruct the data type + auto children = field->children(); + std::vector> child_fields(children->size()); + for (size_t i = 0; i < children->size(); ++i) { + RETURN_NOT_OK( + FieldFromFlatbuffer(children->Get(i), dictionary_memo, &child_fields[i])); + } + RETURN_NOT_OK( + TypeFromFlatbuffer(field->type_type(), field->type(), child_fields, &type)); + } else { + // The field is dictionary encoded. The type of the dictionary values has + // been determined elsewhere, and is stored in the DictionaryMemo. Here we + // construct the logical DictionaryType object + + std::shared_ptr dictionary; + RETURN_NOT_OK(dictionary_memo.GetDictionary(encoding->id(), &dictionary)); + + std::shared_ptr index_type; + RETURN_NOT_OK(IntFromFlatbuffer(encoding->indexType(), &index_type)); + type = std::make_shared(index_type, dictionary); + } + *out = std::make_shared(field->name()->str(), type, field->nullable()); + return Status::OK(); +} + // Implement MessageBuilder // will return the endianness of the system we are running on @@ -281,13 +385,13 @@ flatbuf::Endianness endianness() { return bint.c[0] == 1 ? flatbuf::Endianness_Big : flatbuf::Endianness_Little; } -Status SchemaToFlatbuffer( - FBB& fbb, const Schema& schema, flatbuffers::Offset* out) { +Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, DictionaryMemo* dictionary_memo, + flatbuffers::Offset* out) { std::vector field_offsets; for (int i = 0; i < schema.num_fields(); ++i) { std::shared_ptr field = schema.field(i); FieldOffset offset; - RETURN_NOT_OK(FieldToFlatbuffer(fbb, field, &offset)); + RETURN_NOT_OK(FieldToFlatbuffer(fbb, field, dictionary_memo, &offset)); field_offsets.push_back(offset); } @@ -295,29 +399,63 @@ Status SchemaToFlatbuffer( return Status::OK(); } -Status MessageBuilder::SetSchema(const Schema& schema) { - flatbuffers::Offset fb_schema; - RETURN_NOT_OK(SchemaToFlatbuffer(fbb_, schema, &fb_schema)); +class MessageBuilder { + public: + Status SetSchema(const Schema& schema, DictionaryMemo* dictionary_memo) { + flatbuffers::Offset fb_schema; + RETURN_NOT_OK(SchemaToFlatbuffer(fbb_, schema, dictionary_memo, &fb_schema)); - header_type_ = flatbuf::MessageHeader_Schema; - header_ = fb_schema.Union(); - body_length_ = 0; - return Status::OK(); -} + header_type_ = flatbuf::MessageHeader_Schema; + header_ = fb_schema.Union(); + body_length_ = 0; + return Status::OK(); + } -Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length, - const std::vector& nodes, - const std::vector& buffers) { - header_type_ = flatbuf::MessageHeader_RecordBatch; - header_ = flatbuf::CreateRecordBatch(fbb_, length, fbb_.CreateVectorOfStructs(nodes), - fbb_.CreateVectorOfStructs(buffers)) - .Union(); - body_length_ = body_length; + Status SetRecordBatch(int32_t length, int64_t body_length, + const std::vector& nodes, + const std::vector& buffers) { + header_type_ = flatbuf::MessageHeader_RecordBatch; + header_ = flatbuf::CreateRecordBatch(fbb_, length, fbb_.CreateVectorOfStructs(nodes), + fbb_.CreateVectorOfStructs(buffers)) + .Union(); + body_length_ = body_length; - return Status::OK(); + return Status::OK(); + } + + Status SetDictionary(int64_t id, int32_t length, int64_t body_length, + const std::vector& nodes, + const std::vector& buffers) { + header_type_ = flatbuf::MessageHeader_DictionaryBatch; + + auto record_batch = flatbuf::CreateRecordBatch(fbb_, length, + fbb_.CreateVectorOfStructs(nodes), fbb_.CreateVectorOfStructs(buffers)); + + header_ = flatbuf::CreateDictionaryBatch(fbb_, id, record_batch).Union(); + body_length_ = body_length; + return Status::OK(); + } + + Status Finish(); + + Status GetBuffer(std::shared_ptr* out); + + private: + flatbuf::MessageHeader header_type_; + flatbuffers::Offset header_; + int64_t body_length_; + flatbuffers::FlatBufferBuilder fbb_; +}; + +Status WriteSchemaMessage( + const Schema& schema, DictionaryMemo* dictionary_memo, std::shared_ptr* out) { + MessageBuilder message; + RETURN_NOT_OK(message.SetSchema(schema, dictionary_memo)); + RETURN_NOT_OK(message.Finish()); + return message.GetBuffer(out); } -Status WriteRecordBatchMetadata(int32_t length, int64_t body_length, +Status WriteRecordBatchMessage(int32_t length, int64_t body_length, const std::vector& nodes, const std::vector& buffers, std::shared_ptr* out) { MessageBuilder builder; @@ -326,6 +464,15 @@ Status WriteRecordBatchMetadata(int32_t length, int64_t body_length, return builder.GetBuffer(out); } +Status WriteDictionaryMessage(int64_t id, int32_t length, int64_t body_length, + const std::vector& nodes, + const std::vector& buffers, std::shared_ptr* out) { + MessageBuilder builder; + RETURN_NOT_OK(builder.SetDictionary(id, length, body_length, nodes, buffers)); + RETURN_NOT_OK(builder.Finish()); + return builder.GetBuffer(out); +} + Status MessageBuilder::Finish() { auto message = flatbuf::CreateMessage(fbb_, kMetadataVersion, header_type_, header_, body_length_); diff --git a/cpp/src/arrow/ipc/metadata-internal.h b/cpp/src/arrow/ipc/metadata-internal.h index d94a8abc99a..59afecbcbd2 100644 --- a/cpp/src/arrow/ipc/metadata-internal.h +++ b/cpp/src/arrow/ipc/metadata-internal.h @@ -46,31 +46,34 @@ using Offset = flatbuffers::Offset; static constexpr flatbuf::MetadataVersion kMetadataVersion = flatbuf::MetadataVersion_V2; -Status FieldFromFlatbuffer(const flatbuf::Field* field, std::shared_ptr* out); +// Construct a field with type for a dictionary-encoded field. None of its +// children or children's descendents can be dictionary encoded +Status FieldFromFlatbufferDictionary( + const flatbuf::Field* field, std::shared_ptr* out); -Status SchemaToFlatbuffer( - FBB& fbb, const Schema& schema, flatbuffers::Offset* out); +// Construct a field for a non-dictionary-encoded field. Its children may be +// dictionary encoded +Status FieldFromFlatbuffer(const flatbuf::Field* field, + const DictionaryMemo& dictionary_memo, std::shared_ptr* out); -class MessageBuilder { - public: - Status SetSchema(const Schema& schema); +Status SchemaToFlatbuffer(FBB& fbb, const Schema& schema, DictionaryMemo* dictionary_memo, + flatbuffers::Offset* out); - Status SetRecordBatch(int32_t length, int64_t body_length, - const std::vector& nodes, - const std::vector& buffers); - - Status Finish(); - - Status GetBuffer(std::shared_ptr* out); - - private: - flatbuf::MessageHeader header_type_; - flatbuffers::Offset header_; - int64_t body_length_; - flatbuffers::FlatBufferBuilder fbb_; -}; +// Serialize arrow::Schema as a Flatbuffer +// +// \param[in] schema a Schema instance +// \param[inout] dictionary_memo class for tracking dictionaries and assigning +// dictionary ids +// \param[out] out the serialized arrow::Buffer +// \return Status outcome +Status WriteSchemaMessage( + const Schema& schema, DictionaryMemo* dictionary_memo, std::shared_ptr* out); + +Status WriteRecordBatchMessage(int32_t length, int64_t body_length, + const std::vector& nodes, + const std::vector& buffers, std::shared_ptr* out); -Status WriteRecordBatchMetadata(int32_t length, int64_t body_length, +Status WriteDictionaryMessage(int64_t id, int32_t length, int64_t body_length, const std::vector& nodes, const std::vector& buffers, std::shared_ptr* out); diff --git a/cpp/src/arrow/ipc/metadata.cc b/cpp/src/arrow/ipc/metadata.cc index a97965c40d6..2ba44ac618c 100644 --- a/cpp/src/arrow/ipc/metadata.cc +++ b/cpp/src/arrow/ipc/metadata.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "flatbuffers/flatbuffers.h" @@ -38,11 +39,60 @@ namespace flatbuf = org::apache::arrow::flatbuf; namespace ipc { -Status WriteSchema(const Schema& schema, std::shared_ptr* out) { - MessageBuilder message; - RETURN_NOT_OK(message.SetSchema(schema)); - RETURN_NOT_OK(message.Finish()); - return message.GetBuffer(out); +// ---------------------------------------------------------------------- +// Memoization data structure for handling shared dictionaries + +DictionaryMemo::DictionaryMemo() {} + +// Returns KeyError if dictionary not found +Status DictionaryMemo::GetDictionary( + int64_t id, std::shared_ptr* dictionary) const { + auto it = id_to_dictionary_.find(id); + if (it == id_to_dictionary_.end()) { + std::stringstream ss; + ss << "Dictionary with id " << id << " not found"; + return Status::KeyError(ss.str()); + } + *dictionary = it->second; + return Status::OK(); +} + +int64_t DictionaryMemo::GetId(const std::shared_ptr& dictionary) { + intptr_t address = reinterpret_cast(dictionary.get()); + auto it = dictionary_to_id_.find(address); + if (it != dictionary_to_id_.end()) { + // Dictionary already observed, return the id + return it->second; + } else { + int64_t new_id = static_cast(dictionary_to_id_.size()) + 1; + dictionary_to_id_[address] = new_id; + id_to_dictionary_[new_id] = dictionary; + return new_id; + } +} + +bool DictionaryMemo::HasDictionary(const std::shared_ptr& dictionary) const { + intptr_t address = reinterpret_cast(dictionary.get()); + auto it = dictionary_to_id_.find(address); + return it != dictionary_to_id_.end(); +} + +bool DictionaryMemo::HasDictionaryId(int64_t id) const { + auto it = id_to_dictionary_.find(id); + return it != id_to_dictionary_.end(); +} + +Status DictionaryMemo::AddDictionary( + int64_t id, const std::shared_ptr& dictionary) { + if (HasDictionaryId(id)) { + std::stringstream ss; + ss << "Dictionary with id " << id << " already exists"; + return Status::KeyError(ss.str()); + } + intptr_t address = reinterpret_cast(dictionary.get()); + id_to_dictionary_[id] = dictionary; + dictionary_to_id_[address] = id; + return Status::OK(); } //---------------------------------------------------------------------- @@ -113,10 +163,35 @@ class SchemaMetadata::SchemaMetadataImpl { explicit SchemaMetadataImpl(const void* schema) : schema_(static_cast(schema)) {} - const flatbuf::Field* field(int i) const { return schema_->fields()->Get(i); } + const flatbuf::Field* get_field(int i) const { return schema_->fields()->Get(i); } int num_fields() const { return schema_->fields()->size(); } + Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) const { + const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary(); + if (dict_metadata == nullptr) { + // Field is not dictionary encoded. Visit children + auto children = field->children(); + for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) { + RETURN_NOT_OK(VisitField(children->Get(i), id_to_field)); + } + } else { + // Field is dictionary encoded. Construct the data type for the + // dictionary (no descendents can be dictionary encoded) + std::shared_ptr dictionary_field; + RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field)); + (*id_to_field)[dict_metadata->id()] = dictionary_field; + } + return Status::OK(); + } + + Status GetDictionaryTypes(DictionaryTypeMap* id_to_field) const { + for (int i = 0; i < num_fields(); ++i) { + RETURN_NOT_OK(VisitField(get_field(i), id_to_field)); + } + return Status::OK(); + } + private: const flatbuf::Schema* schema_; }; @@ -138,15 +213,16 @@ int SchemaMetadata::num_fields() const { return impl_->num_fields(); } -Status SchemaMetadata::GetField(int i, std::shared_ptr* out) const { - const flatbuf::Field* field = impl_->field(i); - return FieldFromFlatbuffer(field, out); +Status SchemaMetadata::GetDictionaryTypes(DictionaryTypeMap* id_to_field) const { + return impl_->GetDictionaryTypes(id_to_field); } -Status SchemaMetadata::GetSchema(std::shared_ptr* out) const { +Status SchemaMetadata::GetSchema( + const DictionaryMemo& dictionary_memo, std::shared_ptr* out) const { std::vector> fields(num_fields()); for (int i = 0; i < this->num_fields(); ++i) { - RETURN_NOT_OK(GetField(i, &fields[i])); + const flatbuf::Field* field = impl_->get_field(i); + RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i])); } *out = std::make_shared(fields); return Status::OK(); @@ -173,28 +249,34 @@ class RecordBatchMetadata::RecordBatchMetadataImpl { int num_fields() const { return batch_->nodes()->size(); } + void set_message(const std::shared_ptr& message) { message_ = message; } + + void set_buffer(const std::shared_ptr& buffer) { buffer_ = buffer; } + private: const flatbuf::RecordBatch* batch_; const flatbuffers::Vector* nodes_; const flatbuffers::Vector* buffers_; + + // Possible parents, owns the flatbuffer data + std::shared_ptr message_; + std::shared_ptr buffer_; }; RecordBatchMetadata::RecordBatchMetadata(const std::shared_ptr& message) { - message_ = message; impl_.reset(new RecordBatchMetadataImpl(message->impl_->header())); + impl_->set_message(message); } -RecordBatchMetadata::RecordBatchMetadata( - const std::shared_ptr& buffer, int64_t offset) { - message_ = nullptr; - buffer_ = buffer; - - const flatbuf::RecordBatch* metadata = - flatbuffers::GetRoot(buffer->data() + offset); - - // TODO(wesm): validate table +RecordBatchMetadata::RecordBatchMetadata(const void* header) { + impl_.reset(new RecordBatchMetadataImpl(header)); +} - impl_.reset(new RecordBatchMetadataImpl(metadata)); +RecordBatchMetadata::RecordBatchMetadata( + const std::shared_ptr& buffer, int64_t offset) + : RecordBatchMetadata(buffer->data() + offset) { + // Preserve ownership + impl_->set_buffer(buffer); } RecordBatchMetadata::~RecordBatchMetadata() {} @@ -232,5 +314,64 @@ int RecordBatchMetadata::num_fields() const { return impl_->num_fields(); } +// ---------------------------------------------------------------------- +// DictionaryBatchMetadata + +class DictionaryBatchMetadata::DictionaryBatchMetadataImpl { + public: + explicit DictionaryBatchMetadataImpl(const void* dictionary) + : metadata_(static_cast(dictionary)) { + record_batch_.reset(new RecordBatchMetadata(metadata_->data())); + } + + int64_t id() const { return metadata_->id(); } + const RecordBatchMetadata& record_batch() const { return *record_batch_; } + + void set_message(const std::shared_ptr& message) { message_ = message; } + + private: + const flatbuf::DictionaryBatch* metadata_; + + std::unique_ptr record_batch_; + + // Parent, owns the flatbuffer data + std::shared_ptr message_; +}; + +DictionaryBatchMetadata::DictionaryBatchMetadata( + const std::shared_ptr& message) { + impl_.reset(new DictionaryBatchMetadataImpl(message->impl_->header())); + impl_->set_message(message); +} + +DictionaryBatchMetadata::~DictionaryBatchMetadata() {} + +int64_t DictionaryBatchMetadata::id() const { + return impl_->id(); +} + +const RecordBatchMetadata& DictionaryBatchMetadata::record_batch() const { + return impl_->record_batch(); +} + +// ---------------------------------------------------------------------- +// Conveniences + +Status ReadMessage(int64_t offset, int32_t metadata_length, + io::ReadableFileInterface* file, std::shared_ptr* message) { + std::shared_ptr buffer; + RETURN_NOT_OK(file->ReadAt(offset, metadata_length, &buffer)); + + int32_t flatbuffer_size = *reinterpret_cast(buffer->data()); + + if (flatbuffer_size + static_cast(sizeof(int32_t)) > metadata_length) { + std::stringstream ss; + ss << "flatbuffer size " << metadata_length << " invalid. File offset: " << offset + << ", metadata length: " << metadata_length; + return Status::Invalid(ss.str()); + } + return Message::Open(buffer, 4, message); +} + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/metadata.h b/cpp/src/arrow/ipc/metadata.h index 81e3dbdf6c4..0091067c322 100644 --- a/cpp/src/arrow/ipc/metadata.h +++ b/cpp/src/arrow/ipc/metadata.h @@ -22,13 +22,17 @@ #include #include +#include #include +#include "arrow/util/macros.h" #include "arrow/util/visibility.h" namespace arrow { +class Array; class Buffer; +struct DataType; struct Field; class Schema; class Status; @@ -36,6 +40,7 @@ class Status; namespace io { class OutputStream; +class ReadableFileInterface; } // namespace io @@ -47,9 +52,38 @@ struct MetadataVersion { //---------------------------------------------------------------------- -// Serialize arrow::Schema as a Flatbuffer -ARROW_EXPORT -Status WriteSchema(const Schema& schema, std::shared_ptr* out); +using DictionaryMap = std::unordered_map>; +using DictionaryTypeMap = std::unordered_map>; + +// Memoization data structure for handling shared dictionaries +class DictionaryMemo { + public: + DictionaryMemo(); + + // Returns KeyError if dictionary not found + Status GetDictionary(int64_t id, std::shared_ptr* dictionary) const; + + int64_t GetId(const std::shared_ptr& dictionary); + + bool HasDictionary(const std::shared_ptr& dictionary) const; + bool HasDictionaryId(int64_t id) const; + + // Add a dictionary to the memo with a particular id. Returns KeyError if + // that dictionary already exists + Status AddDictionary(int64_t id, const std::shared_ptr& dictionary); + + const DictionaryMap& id_to_dictionary() const { return id_to_dictionary_; } + + private: + // Dictionary memory addresses, to track whether a dictionary has been seen + // before + std::unordered_map dictionary_to_id_; + + // Map of dictionary id to dictionary array + DictionaryMap id_to_dictionary_; + + DISALLOW_COPY_AND_ASSIGN(DictionaryMemo); +}; // Read interface classes. We do not fully deserialize the flatbuffers so that // individual fields metadata can be retrieved from very large schema without @@ -69,12 +103,15 @@ class ARROW_EXPORT SchemaMetadata { int num_fields() const; - // Construct an arrow::Field for the i-th value in the metadata - Status GetField(int i, std::shared_ptr* out) const; + // Retrieve a list of all the dictionary ids and types required by the schema for + // reconstruction. The presumption is that these will be loaded either from + // the stream or file (or they may already be somewhere else in memory) + Status GetDictionaryTypes(DictionaryTypeMap* id_to_field) const; // Construct a complete Schema from the message. May be expensive for very // large schemas if you are only interested in a few fields - Status GetSchema(std::shared_ptr* out) const; + Status GetSchema( + const DictionaryMemo& dictionary_memo, std::shared_ptr* out) const; private: // Parent, owns the flatbuffer data @@ -82,6 +119,8 @@ class ARROW_EXPORT SchemaMetadata { class SchemaMetadataImpl; std::unique_ptr impl_; + + DISALLOW_COPY_AND_ASSIGN(SchemaMetadata); }; // Field metadata @@ -99,8 +138,10 @@ struct ARROW_EXPORT BufferMetadata { // Container for serialized record batch metadata contained in an IPC message class ARROW_EXPORT RecordBatchMetadata { public: + // Instantiate from opaque pointer. Memory ownership must be preserved + // elsewhere (e.g. in a dictionary batch) + explicit RecordBatchMetadata(const void* header); explicit RecordBatchMetadata(const std::shared_ptr& message); - RecordBatchMetadata(const std::shared_ptr& message, int64_t offset); ~RecordBatchMetadata(); @@ -113,18 +154,25 @@ class ARROW_EXPORT RecordBatchMetadata { int num_fields() const; private: - // Parent, owns the flatbuffer data - std::shared_ptr message_; - std::shared_ptr buffer_; - class RecordBatchMetadataImpl; std::unique_ptr impl_; + + DISALLOW_COPY_AND_ASSIGN(RecordBatchMetadata); }; class ARROW_EXPORT DictionaryBatchMetadata { public: + explicit DictionaryBatchMetadata(const std::shared_ptr& message); + ~DictionaryBatchMetadata(); + int64_t id() const; - std::unique_ptr data() const; + const RecordBatchMetadata& record_batch() const; + + private: + class DictionaryBatchMetadataImpl; + std::unique_ptr impl_; + + DISALLOW_COPY_AND_ASSIGN(DictionaryBatchMetadata); }; class ARROW_EXPORT Message { @@ -141,24 +189,31 @@ class ARROW_EXPORT Message { private: Message(const std::shared_ptr& buffer, int64_t offset); + friend class DictionaryBatchMetadata; friend class RecordBatchMetadata; friend class SchemaMetadata; // Hide serialization details from user API class MessageImpl; std::unique_ptr impl_; -}; -struct ARROW_EXPORT FileBlock { - FileBlock() {} - FileBlock(int64_t offset, int32_t metadata_length, int64_t body_length) - : offset(offset), metadata_length(metadata_length), body_length(body_length) {} - - int64_t offset; - int32_t metadata_length; - int64_t body_length; + DISALLOW_COPY_AND_ASSIGN(Message); }; +/// Read a length-prefixed message flatbuffer starting at the indicated file +/// offset +/// +/// The metadata_length includes at least the length prefix and the flatbuffer +/// +/// \param[in] offset the position in the file where the message starts. The +/// first 4 bytes after the offset are the message length +/// \param[in] metadata_length the total number of bytes to read from file +/// \param[in] file the seekable file interface to read from +/// \param[out] message the message read +/// \return Status success or failure +Status ReadMessage(int64_t offset, int32_t metadata_length, + io::ReadableFileInterface* file, std::shared_ptr* message); + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/stream.cc b/cpp/src/arrow/ipc/stream.cc index 72eb13465af..7f5c9932330 100644 --- a/cpp/src/arrow/ipc/stream.cc +++ b/cpp/src/arrow/ipc/stream.cc @@ -20,17 +20,20 @@ #include #include #include +#include #include #include "arrow/buffer.h" #include "arrow/io/interfaces.h" #include "arrow/io/memory.h" #include "arrow/ipc/adapter.h" +#include "arrow/ipc/metadata-internal.h" #include "arrow/ipc/metadata.h" #include "arrow/ipc/util.h" #include "arrow/memory_pool.h" #include "arrow/schema.h" #include "arrow/status.h" +#include "arrow/table.h" #include "arrow/util/logging.h" namespace arrow { @@ -39,11 +42,10 @@ namespace ipc { // ---------------------------------------------------------------------- // Stream writer implementation -StreamWriter::~StreamWriter() {} - StreamWriter::StreamWriter(io::OutputStream* sink, const std::shared_ptr& schema) : sink_(sink), schema_(schema), + dictionary_memo_(std::make_shared()), pool_(default_memory_pool()), position_(-1), started_(false) {} @@ -107,7 +109,7 @@ Status StreamWriter::Open(io::OutputStream* sink, const std::shared_ptr& Status StreamWriter::Start() { std::shared_ptr schema_fb; - RETURN_NOT_OK(WriteSchema(*schema_, &schema_fb)); + RETURN_NOT_OK(WriteSchemaMessage(*schema_, dictionary_memo_.get(), &schema_fb)); int32_t flatbuffer_size = schema_fb->size(); RETURN_NOT_OK( @@ -115,14 +117,41 @@ Status StreamWriter::Start() { // Write the flatbuffer RETURN_NOT_OK(Write(schema_fb->data(), flatbuffer_size)); + + // If there are any dictionaries, write them as the next messages + RETURN_NOT_OK(WriteDictionaries()); + started_ = true; return Status::OK(); } Status StreamWriter::WriteRecordBatch(const RecordBatch& batch) { - // Pass FileBlock, but results not used - FileBlock dummy_block; - return WriteRecordBatch(batch, &dummy_block); + // Push an empty FileBlock. Can be written in the footer later + record_batches_.emplace_back(0, 0, 0); + return WriteRecordBatch(batch, &record_batches_[record_batches_.size() - 1]); +} + +Status StreamWriter::WriteDictionaries() { + const DictionaryMap& id_to_dictionary = dictionary_memo_->id_to_dictionary(); + + dictionaries_.resize(id_to_dictionary.size()); + + // TODO(wesm): does sorting by id yield any benefit? + int dict_index = 0; + for (const auto& entry : id_to_dictionary) { + FileBlock* block = &dictionaries_[dict_index++]; + + block->offset = position_; + + // Frame of reference in file format is 0, see ARROW-384 + const int64_t buffer_start_offset = 0; + RETURN_NOT_OK(WriteDictionary(entry.first, entry.second, buffer_start_offset, sink_, + &block->metadata_length, &block->body_length, pool_)); + RETURN_NOT_OK(UpdatePosition()); + DCHECK(position_ % 8 == 0) << "WriteDictionary did not perform aligned writes"; + } + + return Status::OK(); } Status StreamWriter::Close() { @@ -134,81 +163,147 @@ Status StreamWriter::Close() { // ---------------------------------------------------------------------- // StreamReader implementation -StreamReader::StreamReader(const std::shared_ptr& stream) - : stream_(stream), schema_(nullptr) {} - -StreamReader::~StreamReader() {} - -Status StreamReader::Open(const std::shared_ptr& stream, - std::shared_ptr* reader) { - // Private ctor - *reader = std::shared_ptr(new StreamReader(stream)); - return (*reader)->ReadSchema(); +static inline std::string message_type_name(Message::Type type) { + switch (type) { + case Message::SCHEMA: + return "schema"; + case Message::RECORD_BATCH: + return "record batch"; + case Message::DICTIONARY_BATCH: + return "dictionary"; + default: + break; + } + return "unknown"; } -Status StreamReader::ReadSchema() { - std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(&message)); +class StreamReader::StreamReaderImpl { + public: + StreamReaderImpl() {} + ~StreamReaderImpl() {} - if (message->type() != Message::SCHEMA) { - return Status::IOError("First message was not schema type"); + Status Open(const std::shared_ptr& stream) { + stream_ = stream; + return ReadSchema(); } - SchemaMetadata schema_meta(message); + Status ReadNextMessage(Message::Type expected_type, std::shared_ptr* message) { + std::shared_ptr buffer; + RETURN_NOT_OK(stream_->Read(sizeof(int32_t), &buffer)); - // TODO(wesm): If the schema contains dictionaries, we must read all the - // dictionaries from the stream before constructing the final Schema - return schema_meta.GetSchema(&schema_); -} + if (buffer->size() != sizeof(int32_t)) { + *message = nullptr; + return Status::OK(); + } + + int32_t message_length = *reinterpret_cast(buffer->data()); + + RETURN_NOT_OK(stream_->Read(message_length, &buffer)); + if (buffer->size() != message_length) { + return Status::IOError("Unexpected end of stream trying to read message"); + } -Status StreamReader::ReadNextMessage(std::shared_ptr* message) { - std::shared_ptr buffer; - RETURN_NOT_OK(stream_->Read(sizeof(int32_t), &buffer)); + RETURN_NOT_OK(Message::Open(buffer, 0, message)); - if (buffer->size() != sizeof(int32_t)) { - *message = nullptr; + if ((*message)->type() != expected_type) { + std::stringstream ss; + ss << "Message not expected type: " << message_type_name(expected_type) + << ", was: " << (*message)->type(); + return Status::IOError(ss.str()); + } return Status::OK(); } - int32_t message_length = *reinterpret_cast(buffer->data()); + Status ReadExact(int64_t size, std::shared_ptr* buffer) { + RETURN_NOT_OK(stream_->Read(size, buffer)); - RETURN_NOT_OK(stream_->Read(message_length, &buffer)); - if (buffer->size() != message_length) { - return Status::IOError("Unexpected end of stream trying to read message"); + if ((*buffer)->size() < size) { + return Status::IOError("Unexpected EOS when reading buffer"); + } + return Status::OK(); } - return Message::Open(buffer, 0, message); -} -std::shared_ptr StreamReader::schema() const { - return schema_; -} + Status ReadNextDictionary() { + std::shared_ptr message; + RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, &message)); -Status StreamReader::GetNextRecordBatch(std::shared_ptr* batch) { - std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(&message)); + DictionaryBatchMetadata metadata(message); - if (message == nullptr) { - // End of stream - *batch = nullptr; - return Status::OK(); + std::shared_ptr batch_body; + RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body)) + io::BufferReader reader(batch_body); + + std::shared_ptr dictionary; + RETURN_NOT_OK(ReadDictionary(metadata, dictionary_types_, &reader, &dictionary)); + return dictionary_memo_.AddDictionary(metadata.id(), dictionary); } - if (message->type() != Message::RECORD_BATCH) { - return Status::IOError("Metadata not record batch"); + Status ReadSchema() { + std::shared_ptr message; + RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message)); + + SchemaMetadata schema_meta(message); + RETURN_NOT_OK(schema_meta.GetDictionaryTypes(&dictionary_types_)); + + // TODO(wesm): In future, we may want to reconcile the ids in the stream with + // those found in the schema + int num_dictionaries = static_cast(dictionary_types_.size()); + for (int i = 0; i < num_dictionaries; ++i) { + RETURN_NOT_OK(ReadNextDictionary()); + } + + return schema_meta.GetSchema(dictionary_memo_, &schema_); } - auto batch_metadata = std::make_shared(message); + Status GetNextRecordBatch(std::shared_ptr* batch) { + std::shared_ptr message; + RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, &message)); + + if (message == nullptr) { + // End of stream + *batch = nullptr; + return Status::OK(); + } - std::shared_ptr batch_body; - RETURN_NOT_OK(stream_->Read(message->body_length(), &batch_body)); + RecordBatchMetadata batch_metadata(message); - if (batch_body->size() < message->body_length()) { - return Status::IOError("Unexpected EOS when reading message body"); + std::shared_ptr batch_body; + RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body)); + io::BufferReader reader(batch_body); + return ReadRecordBatch(batch_metadata, schema_, &reader, batch); } - io::BufferReader reader(batch_body); + std::shared_ptr schema() const { return schema_; } + + private: + // dictionary_id -> type + DictionaryTypeMap dictionary_types_; + + DictionaryMemo dictionary_memo_; + + std::shared_ptr stream_; + std::shared_ptr schema_; +}; + +StreamReader::StreamReader() { + impl_.reset(new StreamReaderImpl()); +} + +StreamReader::~StreamReader() {} + +Status StreamReader::Open(const std::shared_ptr& stream, + std::shared_ptr* reader) { + // Private ctor + *reader = std::shared_ptr(new StreamReader()); + return (*reader)->impl_->Open(stream); +} + +std::shared_ptr StreamReader::schema() const { + return impl_->schema(); +} - return ReadRecordBatch(batch_metadata, schema_, &reader, batch); +Status StreamReader::GetNextRecordBatch(std::shared_ptr* batch) { + return impl_->GetNextRecordBatch(batch); } } // namespace ipc diff --git a/cpp/src/arrow/ipc/stream.h b/cpp/src/arrow/ipc/stream.h index 12414fa2ca0..1c3f65e49af 100644 --- a/cpp/src/arrow/ipc/stream.h +++ b/cpp/src/arrow/ipc/stream.h @@ -22,7 +22,9 @@ #include #include +#include +#include "arrow/ipc/metadata.h" #include "arrow/util/visibility.h" namespace arrow { @@ -44,12 +46,19 @@ class OutputStream; namespace ipc { -struct FileBlock; -class Message; +struct ARROW_EXPORT FileBlock { + FileBlock() {} + FileBlock(int64_t offset, int32_t metadata_length, int64_t body_length) + : offset(offset), metadata_length(metadata_length), body_length(body_length) {} + + int64_t offset; + int32_t metadata_length; + int64_t body_length; +}; class ARROW_EXPORT StreamWriter { public: - virtual ~StreamWriter(); + virtual ~StreamWriter() = default; static Status Open(io::OutputStream* sink, const std::shared_ptr& schema, std::shared_ptr* out); @@ -72,6 +81,8 @@ class ARROW_EXPORT StreamWriter { Status CheckStarted(); Status UpdatePosition(); + Status WriteDictionaries(); + Status WriteRecordBatch(const RecordBatch& batch, FileBlock* block); // Adds padding bytes if necessary to ensure all memory blocks are written on @@ -87,10 +98,17 @@ class ARROW_EXPORT StreamWriter { io::OutputStream* sink_; std::shared_ptr schema_; + // When writing out the schema, we keep track of all the dictionaries we + // encounter, as they must be written out first in the stream + std::shared_ptr dictionary_memo_; + MemoryPool* pool_; int64_t position_; bool started_; + + std::vector dictionaries_; + std::vector record_batches_; }; class ARROW_EXPORT StreamReader { @@ -107,14 +125,10 @@ class ARROW_EXPORT StreamReader { Status GetNextRecordBatch(std::shared_ptr* batch); private: - explicit StreamReader(const std::shared_ptr& stream); - - Status ReadSchema(); + StreamReader(); - Status ReadNextMessage(std::shared_ptr* message); - - std::shared_ptr stream_; - std::shared_ptr schema_; + class ARROW_NO_EXPORT StreamReaderImpl; + std::unique_ptr impl_; }; } // namespace ipc diff --git a/cpp/src/arrow/ipc/test-common.h b/cpp/src/arrow/ipc/test-common.h index b4930c4555d..07f786c4d1d 100644 --- a/cpp/src/arrow/ipc/test-common.h +++ b/cpp/src/arrow/ipc/test-common.h @@ -345,6 +345,86 @@ Status MakeUnion(std::shared_ptr* out) { return Status::OK(); } +Status MakeDictionary(std::shared_ptr* out) { + const int32_t length = 6; + + std::vector is_valid = {true, true, false, true, true, true}; + std::shared_ptr dict1, dict2; + + std::vector dict1_values = {"foo", "bar", "baz"}; + std::vector dict2_values = {"foo", "bar", "baz", "qux"}; + + ArrayFromVector(dict1_values, &dict1); + ArrayFromVector(dict2_values, &dict2); + + auto f0_type = arrow::dictionary(arrow::int32(), dict1); + auto f1_type = arrow::dictionary(arrow::int8(), dict1); + auto f2_type = arrow::dictionary(arrow::int32(), dict2); + + std::shared_ptr indices0, indices1, indices2; + std::vector indices0_values = {1, 2, -1, 0, 2, 0}; + std::vector indices1_values = {0, 0, 2, 2, 1, 1}; + std::vector indices2_values = {3, 0, 2, 1, 0, 2}; + + ArrayFromVector(is_valid, indices0_values, &indices0); + ArrayFromVector(is_valid, indices1_values, &indices1); + ArrayFromVector(is_valid, indices2_values, &indices2); + + auto a0 = std::make_shared(f0_type, indices0); + auto a1 = std::make_shared(f1_type, indices1); + auto a2 = std::make_shared(f2_type, indices2); + + // List of dictionary-encoded string + auto f3_type = list(f1_type); + + std::vector list_offsets = {0, 0, 2, 2, 5, 6, 9}; + std::shared_ptr offsets, indices3; + ArrayFromVector( + std::vector(list_offsets.size(), true), list_offsets, &offsets); + + std::vector indices3_values = {0, 1, 2, 0, 1, 2, 0, 1, 2}; + std::vector is_valid3(9, true); + ArrayFromVector(is_valid3, indices3_values, &indices3); + + std::shared_ptr null_bitmap; + RETURN_NOT_OK(test::GetBitmapFromBoolVector(is_valid, &null_bitmap)); + + std::shared_ptr a3 = std::make_shared(f3_type, length, + std::static_pointer_cast(offsets)->data(), + std::make_shared(f1_type, indices3), null_bitmap, 1); + + // Dictionary-encoded list of integer + auto f4_value_type = list(int8()); + + std::shared_ptr offsets4, values4, indices4; + + std::vector list_offsets4 = {0, 2, 2, 3}; + ArrayFromVector( + std::vector(4, true), list_offsets4, &offsets4); + + std::vector list_values4 = {0, 1, 2}; + ArrayFromVector(std::vector(3, true), list_values4, &values4); + + auto dict3 = std::make_shared(f4_value_type, 3, + std::static_pointer_cast(offsets4)->data(), values4); + + std::vector indices4_values = {0, 1, 2, 0, 1, 2}; + ArrayFromVector(is_valid, indices4_values, &indices4); + + auto f4_type = dictionary(int8(), dict3); + auto a4 = std::make_shared(f4_type, indices4); + + // construct batch + std::shared_ptr schema(new Schema({field("dict1", f0_type), + field("sparse", f1_type), field("dense", f2_type), + field("list of encoded string", f3_type), field("encoded list", f4_type)})); + + std::vector> arrays = {a0, a1, a2, a3, a4}; + + out->reset(new RecordBatch(schema, length, arrays)); + return Status::OK(); +} + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index a1c2b79950d..b97b4657c36 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -29,7 +29,7 @@ namespace arrow { bool Field::Equals(const Field& other) const { return (this == &other) || (this->name == other.name && this->nullable == other.nullable && - this->dictionary == dictionary && this->type->Equals(*other.type.get())); + this->type->Equals(*other.type.get())); } bool Field::Equals(const std::shared_ptr& other) const { @@ -234,8 +234,8 @@ std::shared_ptr dictionary(const std::shared_ptr& index_type } std::shared_ptr field( - const std::string& name, const TypePtr& type, bool nullable, int64_t dictionary) { - return std::make_shared(name, type, nullable, dictionary); + const std::string& name, const TypePtr& type, bool nullable) { + return std::make_shared(name, type, nullable); } static const BufferDescr kValidityBuffer(BufferType::VALIDITY, 1); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 927b8a44fe1..b15aa277af2 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -114,6 +114,8 @@ class BufferDescr { class TypeVisitor { public: + virtual ~TypeVisitor() = default; + virtual Status Visit(const NullType& type) = 0; virtual Status Visit(const BooleanType& type) = 0; virtual Status Visit(const Int8Type& type) = 0; @@ -205,13 +207,9 @@ struct ARROW_EXPORT Field { // Fields can be nullable bool nullable; - // optional dictionary id if the field is dictionary encoded - // 0 means it's not dictionary encoded - int64_t dictionary; - Field(const std::string& name, const std::shared_ptr& type, - bool nullable = true, int64_t dictionary = 0) - : name(name), type(type), nullable(nullable), dictionary(dictionary) {} + bool nullable = true) + : name(name), type(type), nullable(nullable) {} bool operator==(const Field& other) const { return this->Equals(other); } bool operator!=(const Field& other) const { return !this->Equals(other); } @@ -556,8 +554,8 @@ std::shared_ptr ARROW_EXPORT union_( std::shared_ptr ARROW_EXPORT dictionary( const std::shared_ptr& index_type, const std::shared_ptr& values); -std::shared_ptr ARROW_EXPORT field(const std::string& name, - const std::shared_ptr& type, bool nullable = true, int64_t dictionary = 0); +std::shared_ptr ARROW_EXPORT field( + const std::string& name, const std::shared_ptr& type, bool nullable = true); // ---------------------------------------------------------------------- // diff --git a/python/pyarrow/includes/libarrow_ipc.pxd b/python/pyarrow/includes/libarrow_ipc.pxd index 5ab98152add..afc7dbd36e5 100644 --- a/python/pyarrow/includes/libarrow_ipc.pxd +++ b/python/pyarrow/includes/libarrow_ipc.pxd @@ -63,7 +63,6 @@ cdef extern from "arrow/ipc/file.h" namespace "arrow::ipc" nogil: shared_ptr[CSchema] schema() - int num_dictionaries() int num_record_batches() CStatus GetRecordBatch(int i, shared_ptr[CRecordBatch]* batch) diff --git a/python/pyarrow/io.pyx b/python/pyarrow/io.pyx index 89ce6e785c0..4acef212b4d 100644 --- a/python/pyarrow/io.pyx +++ b/python/pyarrow/io.pyx @@ -995,11 +995,6 @@ cdef class _FileReader: else: check_status(CFileReader.Open(reader, &self.reader)) - property num_dictionaries: - - def __get__(self): - return self.reader.get().num_dictionaries() - property num_record_batches: def __get__(self):