From fc70f931a67e0796a9545b6df48a7487dad65132 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 28 Mar 2019 14:38:02 +0100 Subject: [PATCH] ARROW-5052: [C++] Add IncompleteDictionaryType This allows passing information about a dictionary type with known index type and value type, but unknown dictionary values. --- cpp/src/arrow/CMakeLists.txt | 2 + cpp/src/arrow/array-dict-test.cc | 31 ++++- cpp/src/arrow/array.cc | 4 + cpp/src/arrow/builder.cc | 41 +++++-- cpp/src/arrow/compare.cc | 20 ++-- cpp/src/arrow/compute/kernels/take.cc | 4 + cpp/src/arrow/ipc/metadata-internal.cc | 56 ++++++--- cpp/src/arrow/ipc/read-write-test.cc | 25 ++++ cpp/src/arrow/ipc/reader.cc | 18 +-- cpp/src/arrow/ipc/reader.h | 1 + cpp/src/arrow/type-test.cc | 137 ++++++++++++++++++++- cpp/src/arrow/type.cc | 158 +++++++++++++++++++++++++ cpp/src/arrow/type.h | 82 ++++++++++++- cpp/src/arrow/type_fwd.h | 2 + cpp/src/arrow/visitor.cc | 1 + cpp/src/arrow/visitor.h | 1 + cpp/src/arrow/visitor_inline.h | 11 +- 17 files changed, 538 insertions(+), 56 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index c04570464c3..1bacb846552 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -317,6 +317,8 @@ arrow_add_pkg_config("arrow") add_arrow_test(allocator-test) if(WIN32) + # XXX This bogus special case because of MinGW + # see https://github.com/apache/arrow/pull/3693 add_arrow_test(array-test SOURCES array-test.cc diff --git a/cpp/src/arrow/array-dict-test.cc b/cpp/src/arrow/array-dict-test.cc index daa7b343762..df0b2651a1a 100644 --- a/cpp/src/arrow/array-dict-test.cc +++ b/cpp/src/arrow/array-dict-test.cc @@ -98,27 +98,52 @@ TYPED_TEST(TestDictionaryBuilder, ArrayInit) { } TYPED_TEST(TestDictionaryBuilder, MakeBuilder) { + // Explicit dictionary values are provided auto dict_array = ArrayFromJSON(std::make_shared(), "[1, 2]"); auto dict_type = dictionary(int8(), dict_array); std::unique_ptr boxed_builder; ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder)); auto& builder = checked_cast&>(*boxed_builder); - ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.Append(static_cast(2))); ASSERT_OK(builder.Append(static_cast(2))); ASSERT_OK(builder.Append(static_cast(1))); ASSERT_OK(builder.AppendNull()); ASSERT_EQ(builder.length(), 4); ASSERT_EQ(builder.null_count(), 1); + std::shared_ptr result; + ASSERT_OK(builder.Finish(&result)); // Build expected data + auto int_array = ArrayFromJSON(int8(), "[1, 1, 0, null]"); + DictionaryArray expected(dict_type, int_array); + + AssertArraysEqual(expected, *result); +} + +TYPED_TEST(TestDictionaryBuilder, MakeBuilderFromIncompleteDictType) { + // Dictionary values are inferred as an IncompleteDictionaryType is passed + auto value_type = std::make_shared(); + auto dict_type = incomplete_dictionary(int8(), value_type); + std::unique_ptr boxed_builder; + ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder)); + auto& builder = checked_cast&>(*boxed_builder); + + ASSERT_OK(builder.Append(static_cast(2))); + ASSERT_OK(builder.Append(static_cast(2))); + ASSERT_OK(builder.Append(static_cast(1))); + ASSERT_OK(builder.AppendNull()); + ASSERT_EQ(builder.length(), 4); + ASSERT_EQ(builder.null_count(), 1); std::shared_ptr result; ASSERT_OK(builder.Finish(&result)); - auto int_array = ArrayFromJSON(int8(), "[0, 1, 0, null]"); - DictionaryArray expected(dict_type, int_array); + // Build expected data + auto int_array = ArrayFromJSON(int8(), "[0, 0, 1, null]"); + auto actual_dict_type = dictionary(int8(), ArrayFromJSON(value_type, "[2, 1]")); + DictionaryArray expected(actual_dict_type, int_array); AssertArraysEqual(expected, *result); } diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 5956dd29caa..58c9759e807 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -938,6 +938,10 @@ class ArrayDataWrapper { return Status::OK(); } + Status Visit(const IncompleteDictionaryType& type) { + return Status::TypeError("Cannot create array of type '", type.ToString(), "'"); + } + Status Visit(const ExtensionType& type) { *out_ = type.MakeArray(data_); return Status::OK(); diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 9669c08e0cd..b435ba73681 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -25,7 +25,6 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/hashing.h" #include "arrow/visitor_inline.h" namespace arrow { @@ -35,11 +34,21 @@ class MemoryPool; // ---------------------------------------------------------------------- // Helper functions -#define BUILDER_CASE(ENUM, BuilderType) \ - case Type::ENUM: \ - out->reset(new BuilderType(type, pool)); \ - return Status::OK(); +template +static Status CreateDictBuilder(const DictionaryType& dict_type, MemoryPool* pool, + std::unique_ptr* out) { + out->reset(new BuilderType(dict_type.dictionary(), pool)); + return Status::OK(); +} + +template +static Status CreateDictBuilder(const IncompleteDictionaryType& dict_type, + MemoryPool* pool, std::unique_ptr* out) { + out->reset(new BuilderType(dict_type.value_type(), pool)); + return Status::OK(); +} +template struct DictionaryBuilderCase { template Status Visit(const ValueType&, typename ValueType::c_type* = nullptr) { @@ -65,12 +74,11 @@ struct DictionaryBuilderCase { template Status Create() { - out->reset(new BuilderType(dict_type.dictionary(), pool)); - return Status::OK(); + return CreateDictBuilder(dict_type, pool, out); } MemoryPool* pool; - const DictionaryType& dict_type; + const DICT_TYPE& dict_type; std::unique_ptr* out; }; @@ -80,6 +88,11 @@ struct DictionaryBuilderCase { // TODO(wesm): come up with a less monolithic strategy Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out) { +#define BUILDER_CASE(ENUM, BuilderType) \ + case Type::ENUM: \ + out->reset(new BuilderType(type, pool)); \ + return Status::OK(); + switch (type->id()) { case Type::NA: { out->reset(new NullBuilder(pool)); @@ -107,10 +120,16 @@ Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, BUILDER_CASE(FIXED_SIZE_BINARY, FixedSizeBinaryBuilder); BUILDER_CASE(DECIMAL, Decimal128Builder); case Type::DICTIONARY: { - const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type, out}; + const auto& dict_type = internal::checked_cast(*type); + DictionaryBuilderCase visitor = {pool, dict_type, out}; return VisitTypeInline(*dict_type.dictionary()->type(), &visitor); } + case Type::INCOMPLETE_DICTIONARY: { + const auto& dict_type = + internal::checked_cast(*type); + DictionaryBuilderCase visitor = {pool, dict_type, out}; + return VisitTypeInline(*dict_type.value_type(), &visitor); + } case Type::LIST: { std::unique_ptr value_builder; std::shared_ptr value_type = @@ -138,6 +157,8 @@ Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, type->ToString()); } } + +#undef BUILDER_CASE } } // namespace arrow diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 5eeefdfbc48..3301547698e 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -775,6 +775,15 @@ class TypeEqualsVisitor { return Status::OK(); } + Status Visit(const IncompleteDictionaryType& left) { + const auto& right = checked_cast(right_); + result_ = left.index_type()->Equals(right.index_type()) && + left.value_type()->Equals(right.value_type()) && + (left.dictionary_id() == right.dictionary_id()) && + (left.ordered() == right.ordered()); + return Status::OK(); + } + Status Visit(const ExtensionType& left) { result_ = left.ExtensionEquals(static_cast(right_)); return Status::OK(); @@ -843,15 +852,8 @@ class ScalarEqualsVisitor { return Status::OK(); } - Status Visit(const UnionScalar& left) { return Status::NotImplemented("union"); } - - Status Visit(const DictionaryScalar& left) { - return Status::NotImplemented("dictionary"); - } - - Status Visit(const ExtensionScalar& left) { - return Status::NotImplemented("extension"); - } + // Default case + Status Visit(const Scalar& left) { return Status::NotImplemented(left.type->name()); } bool result() const { return result_; } diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 1dd34a92449..d0e16f49490 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -177,6 +177,10 @@ struct UnpackValues { return Status::NotImplemented("gathering values of type ", t); } + Status Visit(const IncompleteDictionaryType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + const TakeParameters& params_; }; diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index dedeee3a632..d1bae25161c 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -588,6 +588,10 @@ class FieldToFlatbufferVisitor { return Status::OK(); } + Status Visit(const IncompleteDictionaryType& type) { + return Status::NotImplemented("incomplete dictionary"); + } + Status GetResult(const Field& field, FieldOffset* offset) { auto fb_name = fbb_.CreateString(field.name()); RETURN_NOT_OK(VisitType(*field.type())); @@ -644,6 +648,24 @@ static Status GetFieldMetadata(const flatbuf::Field* field, return Status::OK(); } +static Status FieldFromFlatbuffer(const flatbuf::Field* field, + const DictionaryMemo& dictionary_memo, + std::shared_ptr* out); + +// Reconstruct the data type of a flatbuffer-encoded field +static Status ReconstructFieldType(const flatbuf::Field* field, + const KeyValueMetadata* metadata, + const DictionaryMemo& dictionary_memo, + std::shared_ptr* out) { + auto children = field->children(); + std::vector> child_fields(children->size()); + for (int i = 0; i < static_cast(children->size()); ++i) { + RETURN_NOT_OK( + FieldFromFlatbuffer(children->Get(i), dictionary_memo, &child_fields[i])); + } + return TypeFromFlatbuffer(field, child_fields, metadata, out); +} + static Status FieldFromFlatbuffer(const flatbuf::Field* field, const DictionaryMemo& dictionary_memo, std::shared_ptr* out) { @@ -657,24 +679,25 @@ static Status FieldFromFlatbuffer(const flatbuf::Field* field, if (encoding == nullptr) { // The field is not dictionary encoded. We must potentially visit its // children to fully reconstruct the data type - auto children = field->children(); - std::vector> child_fields(children->size()); - for (int i = 0; i < static_cast(children->size()); ++i) { - RETURN_NOT_OK( - FieldFromFlatbuffer(children->Get(i), dictionary_memo, &child_fields[i])); - } - RETURN_NOT_OK(TypeFromFlatbuffer(field, child_fields, metadata.get(), &type)); + RETURN_NOT_OK(ReconstructFieldType(field, metadata.get(), dictionary_memo, &type)); } else { - // The field is dictionary encoded. The type of the dictionary values has - // been determined elsewhere, and is stored in the DictionaryMemo. Here we - // construct the logical DictionaryType object - - std::shared_ptr dictionary; - RETURN_NOT_OK(dictionary_memo.GetDictionary(encoding->id(), &dictionary)); + // The field is dictionary encoded. Here we + // construct the logical dictionary type. std::shared_ptr index_type; RETURN_NOT_OK(IntFromFlatbuffer(encoding->indexType(), &index_type)); - type = ::arrow::dictionary(index_type, dictionary, encoding->isOrdered()); + std::shared_ptr dictionary; + auto status = dictionary_memo.GetDictionary(encoding->id(), &dictionary); + + if (status.IsKeyError()) { + // Dictionary array not found, need to reconstruct value type + RETURN_NOT_OK(ReconstructFieldType(field, metadata.get(), dictionary_memo, &type)); + type = ::arrow::incomplete_dictionary(index_type, type, encoding->isOrdered(), + encoding->id()); + } else { + RETURN_NOT_OK(status); // Handle other errors + type = ::arrow::dictionary(index_type, dictionary, encoding->isOrdered()); + } } *out = std::make_shared(field->name()->str(), type, field->nullable(), metadata); @@ -684,16 +707,13 @@ static Status FieldFromFlatbuffer(const flatbuf::Field* field, static Status FieldFromFlatbufferDictionary(const flatbuf::Field* field, std::shared_ptr* out) { - // Need an empty memo to pass down for constructing children - DictionaryMemo dummy_memo; - // Any DictionaryEncoding set is ignored here std::shared_ptr type; auto children = field->children(); std::vector> child_fields(children->size()); for (int i = 0; i < static_cast(children->size()); ++i) { - RETURN_NOT_OK(FieldFromFlatbuffer(children->Get(i), dummy_memo, &child_fields[i])); + RETURN_NOT_OK(FieldFromFlatbufferDictionary(children->Get(i), &child_fields[i])); } std::shared_ptr metadata; diff --git a/cpp/src/arrow/ipc/read-write-test.cc b/cpp/src/arrow/ipc/read-write-test.cc index 0408a1712fe..354e9dc2db2 100644 --- a/cpp/src/arrow/ipc/read-write-test.cc +++ b/cpp/src/arrow/ipc/read-write-test.cc @@ -198,6 +198,31 @@ TEST_F(TestSchemaMetadata, DictionaryFields) { } } +TEST_F(TestSchemaMetadata, IncompleteDictionaryFields) { + auto dict_type = dictionary(int8(), ArrayFromJSON(int32(), "[6, 5, 4]")); + auto f0 = field("f0", dict_type); + auto f1 = field("f1", list(dict_type)); + Schema schema({f0, f1}); + + std::shared_ptr buffer; + ASSERT_OK(SerializeSchema(schema, default_memory_pool(), &buffer)); + + // Only read one message. It will not contain the schema dictionaries. + std::unique_ptr message; + std::shared_ptr result; + io::BufferReader reader(buffer); + ASSERT_OK(ReadMessage(&reader, &message)); + ASSERT_OK(ReadSchema(*message, &result)); + + // Decoding the schema should give us incomplete dictionary types. + auto incomplete_dict_type = + incomplete_dictionary(int8(), int32(), false /* ordered */, 0 /* dictionary_id */); + f0 = field("f0", incomplete_dict_type); + f1 = field("f1", list(incomplete_dict_type)); + Schema expected({f0, f1}); + AssertSchemaEqual(expected, *result); +} + TEST_F(TestSchemaMetadata, KeyValueMetadata) { auto field_metadata = key_value_metadata({{"key", "value"}}); auto schema_metadata = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}}); diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 85c64004aa6..26c29756dab 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -232,11 +232,8 @@ class ArrayLoader { } template - typename std::enable_if::value && - !std::is_base_of::value && - !std::is_base_of::value, - Status>::type - Visit(const T& type) { + typename std::enable_if::value, Status>::type Visit( + const T& type) { return LoadPrimitive(); } @@ -292,6 +289,10 @@ class ArrayLoader { return Status::OK(); } + Status Visit(const IncompleteDictionaryType& type) { + return Status::TypeError("Incomplete dictionary encountered, bad IPC stream?"); + } + Status Visit(const ExtensionType& type) { RETURN_NOT_OK(LoadArray(type.storage_type(), context_, out_)); out_->type = type_; @@ -373,9 +374,10 @@ Status ReadRecordBatch(const Buffer& metadata, const std::shared_ptr& sc return ReadRecordBatch(batch, schema, max_recursion_depth, file, out); } -Status ReadDictionary(const Buffer& metadata, const DictionaryTypeMap& dictionary_types, - io::RandomAccessFile* file, int64_t* dictionary_id, - std::shared_ptr* out) { +static Status ReadDictionary(const Buffer& metadata, + const DictionaryTypeMap& dictionary_types, + io::RandomAccessFile* file, int64_t* dictionary_id, + std::shared_ptr* out) { auto message = flatbuf::GetMessage(metadata.data()); auto dictionary_batch = reinterpret_cast(message->header()); diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 8fe310f5b77..78616a0916f 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -21,6 +21,7 @@ #define ARROW_IPC_READER_H #include +#include #include #include "arrow/ipc/message.h" diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc index b869aec1b91..e150dc8f6c7 100644 --- a/cpp/src/arrow/type-test.cc +++ b/cpp/src/arrow/type-test.cc @@ -282,12 +282,48 @@ TEST_F(TestSchema, TestRemoveMetadata) { auto f1 = field("f1", uint8(), false); auto f2 = field("f2", utf8()); std::vector> fields = {f0, f1, f2}; - KeyValueMetadata metadata({"foo", "bar"}, {"bizz", "buzz"}); - auto schema = std::make_shared(fields); + auto metadata = std::shared_ptr( + new KeyValueMetadata({"foo", "bar"}, {"bizz", "buzz"})); + auto schema = std::make_shared(fields, metadata); std::shared_ptr new_schema = schema->RemoveMetadata(); ASSERT_TRUE(new_schema->metadata() == nullptr); } +TEST_F(TestSchema, SetDictionary) { + auto id1 = incomplete_dictionary(int32(), utf8(), false, 12 /* dictionary_id */); + auto id2 = incomplete_dictionary(int16(), int64(), false, 42 /* dictionary_id */); + + auto f0 = field("f0", int32()); + auto f1 = field("f1", id1); + auto f2 = field("f2", list(id2)); + auto f3 = field("f3", struct_({field("s1", int8()), field("s2", id1)})); + + auto schema = ::arrow::schema({f0, f1, f2, f3}); + + auto v1 = ArrayFromJSON(utf8(), "[\"foo\", \"bar\"]"); + auto v2 = ArrayFromJSON(int64(), "[1000, 2000, 3000]"); + std::shared_ptr actual, expected; + + ASSERT_OK(schema->SetDictionary(12, v1, &actual)); + auto d1 = dictionary(int32(), v1); + f1 = field("f1", d1); + f3 = field("f3", struct_({field("s1", int8()), field("s2", d1)})); + expected = ::arrow::schema({f0, f1, f2, f3}); + ASSERT_TRUE(actual->Equals(*expected)); + + ASSERT_OK(actual->SetDictionary(42, v2, &actual)); + auto d2 = dictionary(int16(), v2); + f2 = field("f2", list(d2)); + expected = ::arrow::schema({f0, f1, f2, f3}); + ASSERT_TRUE(actual->Equals(*expected)); + + // Non-existent dictionary id + ASSERT_RAISES(Invalid, schema->SetDictionary(0, v1, &actual)); + + // Mismatching dictionary value type + ASSERT_RAISES(TypeError, schema->SetDictionary(12, v2, &actual)); +} + #define PRIMITIVE_TEST(KLASS, CTYPE, ENUM, NAME) \ TEST(TypesTest, ARROW_CONCAT(TestPrimitive_, ENUM)) { \ KLASS tp; \ @@ -369,6 +405,18 @@ TEST(TestListType, Basics) { ASSERT_EQ("list>", lt2.ToString()); } +TEST(TestListType, SetChild) { + std::shared_ptr t1, t2, expected; + + t1 = list(int8()); + ASSERT_OK(t1->SetChild(0, utf8(), &t2)); + expected = list(utf8()); + ASSERT_TRUE(t2->Equals(expected)); + + // Bad child index + ASSERT_RAISES(Invalid, t1->SetChild(1, utf8(), &t2)); +} + TEST(TestDateTypes, Attrs) { auto t1 = date32(); auto t2 = date64(); @@ -558,6 +606,91 @@ TEST(TestStructType, GetFieldDuplicates) { ASSERT_EQ(results.size(), 0); } +TEST(TestStructType, SetChild) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", int64()); + auto f2 = field("f1", utf8(), false /* nullable */); + StructType struct_type({f0, f1, f2}); + + std::shared_ptr actual, expected; + ASSERT_OK(struct_type.SetChild(2, list(binary()), &actual)); + + auto f3 = field("f1", list(binary()), false /* nullable */); + expected = struct_({f0, f1, f3}); + ASSERT_TRUE(actual->Equals(*expected)); + + // Bad child index + ASSERT_RAISES(Invalid, struct_type.SetChild(3, list(binary()), &actual)); +} + +TEST(TestUnionType, Basics) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", float32()); + + auto t1 = union_({f0, f1}, {42, 43}); + auto t2 = union_({f0, f1}, {42, 43}); + auto t3 = union_({f1, f0}, {42, 43}); + auto t4 = union_({f0, f1}, {42, 43}, UnionMode::DENSE); + auto t5 = union_({f0, f1}, {43, 44}); + + ASSERT_EQ(t1->ToString(), "union[sparse]"); + ASSERT_EQ(t4->ToString(), "union[dense]"); +} + +TEST(TestUnionType, SetChild) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", float32()); + + auto t1 = union_({f0, f1}, {5, 6}, UnionMode::DENSE); + + std::shared_ptr actual, expected; + ASSERT_OK(t1->SetChild(0, float64(), &actual)); + + auto f2 = field("f0", float64()); + expected = union_({f2, f1}, {5, 6}, UnionMode::DENSE); + ASSERT_TRUE(actual->Equals(*expected)); + + // Bad child index + ASSERT_RAISES(Invalid, t1->SetChild(2, float64(), &actual)); +} + +TEST(TestIncompleteDictionaryType, Equals) { + auto t1 = incomplete_dictionary(int8(), int32()); + auto t2 = incomplete_dictionary(int8(), int32()); + auto t3 = incomplete_dictionary(int16(), int32()); + auto t4 = incomplete_dictionary(int8(), uint32()); + auto t5 = incomplete_dictionary(int8(), int32(), true /* ordered */); + auto t6 = incomplete_dictionary(int8(), int32(), false /* ordered */, 0); + auto t7 = incomplete_dictionary(int8(), int32(), false /* ordered */, 1); + auto t8 = incomplete_dictionary(int8(), int32(), false /* ordered */, 1); + + ASSERT_TRUE(t1->Equals(t2)); + // Different index type + ASSERT_FALSE(t1->Equals(t3)); + // Different value type + ASSERT_FALSE(t1->Equals(t4)); + // Different orderedness + ASSERT_FALSE(t1->Equals(t5)); + // Different dictionary id + ASSERT_FALSE(t1->Equals(t6)); + ASSERT_FALSE(t6->Equals(t7)); + ASSERT_TRUE(t7->Equals(t8)); +} + +TEST(TestIncompleteDictionaryType, Complete) { + auto type = incomplete_dictionary(int8(), int32()); + const auto& dict_type = checked_cast(*type); + auto dict_values = ArrayFromJSON(int32(), "[3, 4, 5, 6]"); + + std::shared_ptr completed, expected; + ASSERT_OK(dict_type.Complete(dict_values, &completed)); + expected = dictionary(int8(), dict_values); + ASSERT_TRUE(completed->Equals(expected)); + + auto bad_type_dict_values = ArrayFromJSON(uint32(), "[3, 4, 5, 6]"); + ASSERT_RAISES(TypeError, dict_type.Complete(bad_type_dict_values, &completed)); +} + TEST(TestDictionaryType, Equals) { auto t1 = dictionary(int8(), ArrayFromJSON(int32(), "[3, 4, 5, 6]")); auto t2 = dictionary(int8(), ArrayFromJSON(int32(), "[3, 4, 5, 6]")); diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 0e0d9fc431c..52d64419951 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -37,6 +37,9 @@ namespace arrow { using internal::checked_cast; +// ---------------------------------------------------------------------- +// Field + bool Field::HasMetadata() const { return (metadata_ != nullptr) && (metadata_->size() > 0); } @@ -101,6 +104,9 @@ std::string Field::ToString() const { return ss.str(); } +// ---------------------------------------------------------------------- +// DataType + DataType::~DataType() {} bool DataType::Equals(const DataType& other, bool check_metadata) const { @@ -114,6 +120,28 @@ bool DataType::Equals(const std::shared_ptr& other) const { return Equals(*other.get()); } +Status DataType::SetChild(int i, const std::shared_ptr& field, + std::shared_ptr* out) const { + if (i < 0 || i >= num_children()) { + return Status::Invalid("Child number ", i, " invalid for type '", ToString(), "'"); + } + return SetChildInternal(i, field, out); +} + +Status DataType::SetChild(int i, const std::shared_ptr& child_type, + std::shared_ptr* out) const { + if (i < 0 || i >= num_children()) { + return Status::Invalid("Child number ", i, " invalid for type '", ToString(), "'"); + } + auto field = children_[i]->WithType(child_type); + return SetChildInternal(i, field, out); +} + +Status DataType::SetChildInternal(int i, const std::shared_ptr& field, + std::shared_ptr* out) const { + return Status::NotImplemented("Cannot set child on type '", ToString(), "'"); +} + std::string BooleanType::ToString() const { return name(); } FloatingPoint::Precision HalfFloatType::precision() const { return FloatingPoint::HALF; } @@ -130,6 +158,13 @@ std::string ListType::ToString() const { return s.str(); } +Status ListType::SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const { + DCHECK_EQ(i, 0); + *out = std::make_shared(child_field); + return Status::OK(); +} + std::string BinaryType::ToString() const { return std::string("binary"); } int FixedSizeBinaryType::bit_width() const { return CHAR_BIT * byte_width(); } @@ -222,6 +257,14 @@ std::string UnionType::ToString() const { return s.str(); } +Status UnionType::SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const { + auto new_children = children_; + new_children[i] = child_field; + *out = std::make_shared(new_children, type_codes_, mode_); + return Status::OK(); +} + // ---------------------------------------------------------------------- // Struct type @@ -301,6 +344,14 @@ std::vector> StructType::GetAllFieldsByName( return result; } +Status StructType::SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const { + auto new_children = children_; + new_children[i] = child_field; + *out = std::make_shared(new_children); + return Status::OK(); +} + // Deprecated methods std::shared_ptr StructType::GetChildByName(const std::string& name) const { @@ -341,6 +392,10 @@ int DictionaryType::bit_width() const { std::shared_ptr DictionaryType::dictionary() const { return dictionary_; } +std::shared_ptr DictionaryType::value_type() const { + return dictionary_->type(); +} + std::string DictionaryType::ToString() const { std::stringstream ss; ss << "dictionarytype()->ToString() @@ -348,6 +403,46 @@ std::string DictionaryType::ToString() const { return ss.str(); } +// ---------------------------------------------------------------------- +// IncompleteDictionaryType + +IncompleteDictionaryType::IncompleteDictionaryType( + const std::shared_ptr& index_type, + const std::shared_ptr& value_type, bool ordered, int64_t dictionary_id) + : FixedWidthType(Type::INCOMPLETE_DICTIONARY), + index_type_(index_type), + value_type_(value_type), + dictionary_id_(dictionary_id), + ordered_(ordered) { +#ifndef NDEBUG + const auto& int_type = checked_cast(*index_type); + DCHECK_EQ(int_type.is_signed(), true) << "dictionary index type should be signed"; +#endif +} + +int IncompleteDictionaryType::bit_width() const { + return checked_cast(*index_type_).bit_width(); +} + +std::string IncompleteDictionaryType::ToString() const { + std::stringstream ss; + ss << "incomplete-dictionaryToString() + << ", indices=" << index_type_->ToString() << ", ordered=" << ordered_ + << ", dictionary_id=" << dictionary_id_ << ">"; + return ss.str(); +} + +Status IncompleteDictionaryType::Complete(const std::shared_ptr& values, + std::shared_ptr* out_type) const { + if (!values->type()->Equals(*value_type_, false /* check_metadata */)) { + return Status::TypeError("IncompleteDictionaryType with value_type = '", + value_type_->ToString(), "' got values with type = '", + values->type()->ToString(), "'"); + } + *out_type = dictionary(index_type_, values, ordered_); + return Status::OK(); +} + // ---------------------------------------------------------------------- // Null type @@ -445,6 +540,62 @@ Status Schema::SetField(int i, const std::shared_ptr& field, return Status::OK(); } +namespace { + +// Recursive helper for SetDictionary() +Status SetDictionaryInternal(int64_t dict_id, const std::shared_ptr& dict_values, + const std::shared_ptr& field, + std::shared_ptr* out) { + const auto& type = field->type(); + auto new_field = field; + if (type->id() == Type::INCOMPLETE_DICTIONARY) { + const auto& dict_type = checked_cast(*type); + if (dict_type.dictionary_id() == dict_id) { + std::shared_ptr new_dict_type; + RETURN_NOT_OK(dict_type.Complete(dict_values, &new_dict_type)); + new_field = new_field->WithType(new_dict_type); + } + } else { + // Recurse over child fields + auto num_children = type->num_children(); + std::shared_ptr new_type = type; + for (int i = 0; i < num_children; ++i) { + const auto& child = type->child(i); + std::shared_ptr new_child; + RETURN_NOT_OK(SetDictionaryInternal(dict_id, dict_values, child, &new_child)); + if (new_child.get() != child.get()) { + // Child field changed => rebuild parent type + RETURN_NOT_OK(type->SetChild(i, new_child, &new_type)); + } + } + if (new_type.get() != type.get()) { + // Parent type changed => rebuild field + new_field = new_field->WithType(new_type); + } + } + *out = std::move(new_field); + return Status::OK(); +} + +} // namespace + +Status Schema::SetDictionary(int64_t dict_id, const std::shared_ptr& dict_values, + std::shared_ptr* out) const { + std::vector> new_fields(num_fields()); + bool changed = false; + for (int i = 0; i < num_fields(); ++i) { + RETURN_NOT_OK( + SetDictionaryInternal(dict_id, dict_values, fields_[i], &new_fields[i])); + DCHECK(new_fields[i]); + changed = changed || (new_fields[i].get() != fields_[i].get()); + } + if (!changed) { + return Status::Invalid("Dictionary id ", dict_id, " not found in schema"); + } + *out = std::make_shared(new_fields, metadata_); + return Status::OK(); +} + bool Schema::HasMetadata() const { return (metadata_ != nullptr) && (metadata_->size() > 0); } @@ -602,6 +753,13 @@ std::shared_ptr dictionary(const std::shared_ptr& index_type return std::make_shared(index_type, dict_values, ordered); } +std::shared_ptr incomplete_dictionary( + const std::shared_ptr& index_type, + const std::shared_ptr& value_type, bool ordered, int64_t dictionary_id) { + return std::make_shared(index_type, value_type, ordered, + dictionary_id); +} + std::shared_ptr field(const std::string& name, const std::shared_ptr& type, bool nullable, const std::shared_ptr& metadata) { diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 4c3537834cf..3d19af4a512 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -136,7 +136,10 @@ struct Type { MAP, /// Custom data type, implemented by user - EXTENSION + EXTENSION, + + /// Incomplete dictionary type + INCOMPLETE_DICTIONARY }; }; @@ -171,6 +174,11 @@ class ARROW_EXPORT DataType { Status Accept(TypeVisitor* visitor) const; + Status SetChild(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const; + Status SetChild(int i, const std::shared_ptr& child_type, + std::shared_ptr* out) const; + /// \brief A string representation of the type, including any children virtual std::string ToString() const = 0; @@ -184,6 +192,9 @@ class ARROW_EXPORT DataType { Type::type id() const { return id_; } protected: + virtual Status SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const; + Type::type id_; std::vector> children_; @@ -450,6 +461,10 @@ class ARROW_EXPORT ListType : public NestedType { std::string ToString() const override; std::string name() const override { return "list"; } + + protected: + Status SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const override; }; /// \brief Concrete type class for variable-size binary data @@ -527,6 +542,10 @@ class ARROW_EXPORT StructType : public NestedType { ARROW_DEPRECATED("Use GetFieldIndex") int GetChildIndex(const std::string& name) const; + protected: + Status SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const override; + private: std::unordered_multimap name_to_index_; }; @@ -578,6 +597,10 @@ class ARROW_EXPORT UnionType : public NestedType { UnionMode::type mode() const { return mode_; } + protected: + Status SetChildInternal(int i, const std::shared_ptr& child_field, + std::shared_ptr* out) const override; + private: UnionMode::type mode_; @@ -749,7 +772,7 @@ class ARROW_EXPORT IntervalType : public FixedWidthType { // DictionaryType (for categorical or dictionary-encoded data) /// Concrete type class for dictionary data -class ARROW_EXPORT DictionaryType : public FixedWidthType { +class ARROW_EXPORT DictionaryType : public FixedWidthType, public ParametricType { public: static constexpr Type::type type_id = Type::DICTIONARY; @@ -760,6 +783,8 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType { std::shared_ptr index_type() const { return index_type_; } + std::shared_ptr value_type() const; + std::shared_ptr dictionary() const; std::string ToString() const override; @@ -791,6 +816,49 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType { bool ordered_; }; +/// Type class representing an incomplete dictionary type, +/// whose index type and value type are known, but whose actual +/// dictionary values are still unknown. +class ARROW_EXPORT IncompleteDictionaryType : public FixedWidthType, + public ParametricType { + public: + static constexpr Type::type type_id = Type::INCOMPLETE_DICTIONARY; + + IncompleteDictionaryType(const std::shared_ptr& index_type, + const std::shared_ptr& value_type, + bool ordered = false, int64_t dictionary_id = -1); + + int bit_width() const override; + + std::shared_ptr index_type() const { return index_type_; } + + std::shared_ptr value_type() const { return value_type_; } + + bool ordered() const { return ordered_; } + + int64_t dictionary_id() const { return dictionary_id_; } + + std::string ToString() const override; + std::string name() const override { return "incomplete-dictionary"; } + + /// \brief Make a dictionary type, providing its values + /// + /// Create a DictionaryType with the same index and value types, + /// and with the given values as dictionary. The value type must + /// be equal to the type of the values array. + /// \param[in] values The array of dictionary values + /// \param[out] out The new DictionaryType instance + Status Complete(const std::shared_ptr& values, + std::shared_ptr* out) const; + + private: + // Must be an integer type (not currently checked) + std::shared_ptr index_type_; + std::shared_ptr value_type_; + int64_t dictionary_id_; + bool ordered_; +}; + // ---------------------------------------------------------------------- // Schema @@ -843,6 +911,10 @@ class ARROW_EXPORT Schema { Status SetField(int i, const std::shared_ptr& field, std::shared_ptr* out) const; + /// \brief EXPERIMENTAL + Status SetDictionary(int64_t dict_id, const std::shared_ptr& dict_values, + std::shared_ptr* out) const; + /// \brief Replace key-value metadata with new metadata /// /// \param[in] metadata new KeyValueMetadata @@ -943,6 +1015,12 @@ std::shared_ptr ARROW_EXPORT dictionary(const std::shared_ptr& index_type, const std::shared_ptr& values, bool ordered = false); +/// \brief Create a IncompleteDictionaryType instance +std::shared_ptr ARROW_EXPORT +incomplete_dictionary(const std::shared_ptr& index_type, + const std::shared_ptr& value_type, bool ordered = false, + int64_t dictionary_id = -1); + /// @} /// \defgroup schema-factories Factory functions for fields and schemas diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 9a8d3ef1a54..8716c0a8918 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -47,6 +47,8 @@ class DictionaryType; class DictionaryArray; class DictionaryScalar; +class IncompleteDictionaryType; + class NullType; class NullArray; class NullBuilder; diff --git a/cpp/src/arrow/visitor.cc b/cpp/src/arrow/visitor.cc index 5a601c898bf..149c8462df3 100644 --- a/cpp/src/arrow/visitor.cc +++ b/cpp/src/arrow/visitor.cc @@ -98,6 +98,7 @@ TYPE_VISITOR_DEFAULT(ListType) TYPE_VISITOR_DEFAULT(StructType) TYPE_VISITOR_DEFAULT(UnionType) TYPE_VISITOR_DEFAULT(DictionaryType) +TYPE_VISITOR_DEFAULT(IncompleteDictionaryType) TYPE_VISITOR_DEFAULT(ExtensionType) #undef TYPE_VISITOR_DEFAULT diff --git a/cpp/src/arrow/visitor.h b/cpp/src/arrow/visitor.h index 9806eff5abc..3aca78f36c3 100644 --- a/cpp/src/arrow/visitor.h +++ b/cpp/src/arrow/visitor.h @@ -89,6 +89,7 @@ class ARROW_EXPORT TypeVisitor { virtual Status Visit(const StructType& type); virtual Status Visit(const UnionType& type); virtual Status Visit(const DictionaryType& type); + virtual Status Visit(const IncompleteDictionaryType& type); virtual Status Visit(const ExtensionType& type); }; diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h index 5e20e78c64b..7b47692565d 100644 --- a/cpp/src/arrow/visitor_inline.h +++ b/cpp/src/arrow/visitor_inline.h @@ -69,10 +69,13 @@ template inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) { switch (type.id()) { ARROW_GENERATE_FOR_ALL_TYPES(TYPE_VISIT_INLINE); + // IncompleteDictionary is not in ARROW_GENERATE_FOR_ALL_TYPES as it + // only makes sense for visiting types, not data. + TYPE_VISIT_INLINE(IncompleteDictionary); default: break; } - return Status::NotImplemented("Type not implemented"); + return Status::NotImplemented("Type '", type.name(), "' not implemented"); } #undef TYPE_VISIT_INLINE @@ -90,7 +93,7 @@ inline Status VisitArrayInline(const Array& array, VISITOR* visitor) { default: break; } - return Status::NotImplemented("Type not implemented"); + return Status::NotImplemented("Type '", array.type()->name(), "' not implemented"); } // Visit an array's data values, in order, without overhead. @@ -244,8 +247,8 @@ inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) { default: break; } - return Status::NotImplemented("Scalar visitor for type not implemented ", - scalar.type->ToString()); + return Status::NotImplemented("Scalar visitor for type '", scalar.type->name(), + "'not implemented "); } #undef TYPE_VISIT_INLINE