diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 24e8eefad15..0550abd3ef8 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -223,6 +223,7 @@ set(ARROW_SRCS util/debug.cc util/decimal.cc util/delimiting.cc + util/dict_util.cc util/float16.cc util/formatting.cc util/future.cc diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 46908439ef5..b295e37dfe7 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -80,6 +80,22 @@ class TestArray : public ::testing::Test { MemoryPool* pool_; }; +void CheckDictionaryNullCount(const std::shared_ptr& dict_type, + const std::string& input_dictionary_json, + const std::string& input_index_json, + const int64_t& expected_null_count, + const int64_t& expected_logical_null_count, + bool expected_may_have_nulls, + bool expected_may_have_logical_nulls) { + std::shared_ptr arr = + DictArrayFromJSON(dict_type, input_index_json, input_dictionary_json); + + ASSERT_EQ(arr->null_count(), expected_null_count); + ASSERT_EQ(arr->ComputeLogicalNullCount(), expected_logical_null_count); + ASSERT_EQ(arr->data()->MayHaveNulls(), expected_may_have_nulls); + ASSERT_EQ(arr->data()->MayHaveLogicalNulls(), expected_may_have_logical_nulls); +} + TEST_F(TestArray, TestNullCount) { // These are placeholders auto data = std::make_shared(nullptr, 0); @@ -127,6 +143,37 @@ TEST_F(TestArray, TestNullCount) { ASSERT_EQ(0, ree_no_nulls->ComputeLogicalNullCount()); ASSERT_FALSE(ree_no_nulls->data()->MayHaveNulls()); ASSERT_FALSE(ree_no_nulls->data()->MayHaveLogicalNulls()); + + // Dictionary type + std::shared_ptr type; + std::shared_ptr dict_type; + + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + + type = boolean(); + dict_type = dictionary(index_type, type); + // no null value + CheckDictionaryNullCount(dict_type, "[]", "[]", 0, 0, false, false); + 
CheckDictionaryNullCount(dict_type, "[true, false]", "[0, 1, 0]", 0, 0, false, false); + + // only indices contain null value + CheckDictionaryNullCount(dict_type, "[true, false]", "[null, 0, 1]", 1, 1, true, + true); + CheckDictionaryNullCount(dict_type, "[true, false]", "[null, null]", 2, 2, true, + true); + + // only dictionary contains null value + CheckDictionaryNullCount(dict_type, "[null, true]", "[]", 0, 0, false, true); + CheckDictionaryNullCount(dict_type, "[null, true, false]", "[0, 1, 0]", 0, 2, false, + true); + + // both indices and dictionary contain null value + CheckDictionaryNullCount(dict_type, "[null, true, false]", "[0, 1, 0, null]", 1, 3, + true, true); + CheckDictionaryNullCount(dict_type, "[null, true, null, false]", "[null, 1, 0, 2, 3]", + 1, 3, true, true); + } } TEST_F(TestArray, TestSlicePreservesAllNullCount) { @@ -137,6 +184,16 @@ TEST_F(TestArray, TestSlicePreservesAllNullCount) { Int32Array arr(/*length=*/100, data, null_bitmap, /*null_count*/ 100); EXPECT_EQ(arr.Slice(1, 99)->data()->null_count, arr.Slice(1, 99)->length()); + + // Dictionary type + std::shared_ptr dict_type = dictionary(int64(), boolean()); + std::shared_ptr dict_arr = + DictArrayFromJSON(dict_type, /*indices=*/"[null, 0, 0, 0, 0, 0, 1, 2, 0, 0]", + /*dictionary=*/"[null, true, false]"); + ASSERT_EQ(dict_arr->null_count(), 1); + ASSERT_EQ(dict_arr->ComputeLogicalNullCount(), 8); + ASSERT_EQ(dict_arr->Slice(2, 8)->null_count(), 0); + ASSERT_EQ(dict_arr->Slice(2, 8)->ComputeLogicalNullCount(), 6); } TEST_F(TestArray, TestLength) { diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 186682be300..3959bf6cc18 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -33,6 +33,7 @@ #include "arrow/type_traits.h" #include "arrow/util/binary_view_util.h" #include "arrow/util/bitmap_ops.h" +#include "arrow/util/dict_util.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/ree_util.h" @@ -93,6 
+94,10 @@ bool RunEndEncodedMayHaveLogicalNulls(const ArrayData& data) { return ArraySpan(data).MayHaveLogicalNulls(); } +bool DictionaryMayHaveLogicalNulls(const ArrayData& data) { + return ArraySpan(data).MayHaveLogicalNulls(); +} + BufferSpan PackVariadicBuffers(util::span> buffers) { return {const_cast(reinterpret_cast(buffers.data())), static_cast(buffers.size() * sizeof(std::shared_ptr))}; @@ -174,7 +179,7 @@ int64_t ArrayData::GetNullCount() const { } int64_t ArrayData::ComputeLogicalNullCount() const { - if (this->buffers[0]) { + if (this->buffers[0] && this->type->id() != Type::DICTIONARY) { return GetNullCount(); } return ArraySpan(*this).ComputeLogicalNullCount(); @@ -520,6 +525,9 @@ int64_t ArraySpan::ComputeLogicalNullCount() const { if (t == Type::RUN_END_ENCODED) { return ree_util::LogicalNullCount(*this); } + if (t == Type::DICTIONARY) { + return dict_util::LogicalNullCount(*this); + } return GetNullCount(); } @@ -617,6 +625,10 @@ bool ArraySpan::RunEndEncodedMayHaveLogicalNulls() const { return ree_util::ValuesArray(*this).MayHaveLogicalNulls(); } +bool ArraySpan::DictionaryMayHaveLogicalNulls() const { + return this->GetNullCount() != 0 || this->dictionary().GetNullCount() != 0; +} + // ---------------------------------------------------------------------- // Implement internal::GetArrayView diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index 40a77640cd1..4c2df838149 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -38,7 +38,7 @@ struct ArrayData; namespace internal { // ---------------------------------------------------------------------- -// Null handling for types without a validity bitmap +// Null handling for types without a validity bitmap and the dictionary type ARROW_EXPORT bool IsNullSparseUnion(const ArrayData& data, int64_t i); ARROW_EXPORT bool IsNullDenseUnion(const ArrayData& data, int64_t i); @@ -46,6 +46,7 @@ ARROW_EXPORT bool IsNullRunEndEncoded(const ArrayData& data, int64_t i); 
ARROW_EXPORT bool UnionMayHaveLogicalNulls(const ArrayData& data); ARROW_EXPORT bool RunEndEncodedMayHaveLogicalNulls(const ArrayData& data); +ARROW_EXPORT bool DictionaryMayHaveLogicalNulls(const ArrayData& data); } // namespace internal // When slicing, we do not know the null count of the sliced range without @@ -280,7 +281,7 @@ struct ARROW_EXPORT ArrayData { /// \brief Return true if the validity bitmap may have 0's in it, or if the /// child arrays (in the case of types without a validity bitmap) may have - /// nulls + /// nulls, or if the dictionary of a dictionary array may have nulls. /// /// This is not a drop-in replacement for MayHaveNulls, as historically /// MayHaveNulls() has been used to check for the presence of a validity @@ -325,6 +326,9 @@ struct ARROW_EXPORT ArrayData { if (t == Type::RUN_END_ENCODED) { return internal::RunEndEncodedMayHaveLogicalNulls(*this); } + if (t == Type::DICTIONARY) { + return internal::DictionaryMayHaveLogicalNulls(*this); + } return null_count.load() != 0; } @@ -505,7 +509,7 @@ struct ARROW_EXPORT ArraySpan { /// \brief Return true if the validity bitmap may have 0's in it, or if the /// child arrays (in the case of types without a validity bitmap) may have - /// nulls + /// nulls, or if the dictionary of a dictionary array may have nulls. 
/// /// \see ArrayData::MayHaveLogicalNulls bool MayHaveLogicalNulls() const { @@ -519,6 +523,9 @@ struct ARROW_EXPORT ArraySpan { if (t == Type::RUN_END_ENCODED) { return RunEndEncodedMayHaveLogicalNulls(); } + if (t == Type::DICTIONARY) { + return DictionaryMayHaveLogicalNulls(); + } return null_count != 0; } @@ -560,6 +567,7 @@ struct ARROW_EXPORT ArraySpan { bool UnionMayHaveLogicalNulls() const; bool RunEndEncodedMayHaveLogicalNulls() const; + bool DictionaryMayHaveLogicalNulls() const; }; namespace internal { diff --git a/cpp/src/arrow/util/dict_util.cc b/cpp/src/arrow/util/dict_util.cc new file mode 100644 index 00000000000..feab2324a40 --- /dev/null +++ b/cpp/src/arrow/util/dict_util.cc @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/util/dict_util.h" +#include "arrow/array/array_dict.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { +namespace dict_util { + +namespace { + +template +int64_t LogicalNullCount(const ArraySpan& span) { + const auto* indices_null_bit_map = span.buffers[0].data; + const auto& dictionary_span = span.dictionary(); + const auto* dictionary_null_bit_map = dictionary_span.buffers[0].data; + + using CType = typename IndexArrowType::c_type; + const CType* indices_data = span.GetValues(1); + int64_t null_count = 0; + for (int64_t i = 0; i < span.length; i++) { + if (indices_null_bit_map != nullptr && + !bit_util::GetBit(indices_null_bit_map, i + span.offset)) { + null_count++; + continue; + } + + CType current_index = indices_data[i]; + if (!bit_util::GetBit(dictionary_null_bit_map, + current_index + dictionary_span.offset)) { + null_count++; + } + } + return null_count; +} + +} // namespace + +int64_t LogicalNullCount(const ArraySpan& span) { + if (span.dictionary().GetNullCount() == 0 || span.length == 0) { + return span.GetNullCount(); + } + + const auto& dict_array_type = internal::checked_cast(*span.type); + switch (dict_array_type.index_type()->id()) { + case Type::UINT8: + return LogicalNullCount(span); + case Type::INT8: + return LogicalNullCount(span); + case Type::UINT16: + return LogicalNullCount(span); + case Type::INT16: + return LogicalNullCount(span); + case Type::UINT32: + return LogicalNullCount(span); + case Type::INT32: + return LogicalNullCount(span); + case Type::UINT64: + return LogicalNullCount(span); + default: + return LogicalNullCount(span); + } +} +} // namespace dict_util +} // namespace arrow diff --git a/cpp/src/arrow/util/dict_util.h b/cpp/src/arrow/util/dict_util.h new file mode 100644 index 00000000000..a92733ae0f6 --- /dev/null +++ b/cpp/src/arrow/util/dict_util.h @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/array/data.h" + +namespace arrow { +namespace dict_util { + +int64_t LogicalNullCount(const ArraySpan& span); + +} // namespace dict_util +} // namespace arrow