From 06b428dbe0fa9e03a11af0d467b57145240ddabe Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 26 Aug 2021 09:33:41 -0400 Subject: [PATCH 01/21] ARROW-13573: [C++] Add DictScalarFromJSON --- cpp/src/arrow/ipc/json_simple.cc | 19 +++++++++++++++++++ cpp/src/arrow/ipc/json_simple.h | 5 +++++ cpp/src/arrow/ipc/json_simple_test.cc | 23 +++++++++++++++++++++++ cpp/src/arrow/testing/gtest_util.cc | 9 +++++++++ cpp/src/arrow/testing/gtest_util.h | 5 +++++ 5 files changed, 61 insertions(+) diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 34b0f3fba59..8347b871b1f 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -969,6 +969,25 @@ Status ScalarFromJSON(const std::shared_ptr& type, return Status::OK(); } +Status DictScalarFromJSON(const std::shared_ptr& type, + util::string_view index_json, util::string_view dictionary_json, + std::shared_ptr* out) { + if (type->id() != Type::DICTIONARY) { + return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type); + } + + const auto& dictionary_type = checked_cast(*type); + + std::shared_ptr index; + std::shared_ptr dictionary; + RETURN_NOT_OK(ScalarFromJSON(dictionary_type.index_type(), index_json, &index)); + RETURN_NOT_OK( + ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary)); + + *out = DictionaryScalar::Make(std::move(index), std::move(dictionary)); + return Status::OK(); +} + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/ipc/json_simple.h b/cpp/src/arrow/ipc/json_simple.h index 4dd3a664aa6..8269bd65326 100644 --- a/cpp/src/arrow/ipc/json_simple.h +++ b/cpp/src/arrow/ipc/json_simple.h @@ -55,6 +55,11 @@ ARROW_EXPORT Status ScalarFromJSON(const std::shared_ptr&, util::string_view json, std::shared_ptr* out); +ARROW_EXPORT +Status DictScalarFromJSON(const std::shared_ptr&, util::string_view index_json, + util::string_view dictionary_json, + std::shared_ptr* out); + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index ce2c37b7957..372f6bf1d72 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -1385,6 +1385,29 @@ TEST(TestScalarFromJSON, Errors) { ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "\"true\"", &scalar)); } +TEST(TestDictScalarFromJSON, Basics) { + auto type = dictionary(int32(), utf8()); + auto dict = R"(["whiskey", "tango", "foxtrot"])"; + auto expected_dictionary = ArrayFromJSON(utf8(), dict); + + for (auto index : {"null", "2", "1", "0"}) { + auto scalar = DictScalarFromJSON(type, index, dict); + auto expected_index = ScalarFromJSON(int32(), index); + AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary), + *scalar); + } +} + +TEST(TestDictScalarFromJSON, Errors) { + auto type = dictionary(int32(), utf8()); + std::shared_ptr scalar; + + ASSERT_RAISES(Invalid, + DictScalarFromJSON(type, "\"not a valid index\"", "[\"\"]", &scalar)); + ASSERT_RAISES(Invalid, DictScalarFromJSON(type, "0", "[1]", + &scalar)); // dict value isn't string +} + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 587154c1f30..24f5edcc6cb 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -446,6 +446,15 @@ std::shared_ptr ScalarFromJSON(const std::shared_ptr& type, return out; } 
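[Editor's note - illustrative, not part of the patch] The new DictScalarFromJSON helpers take the dictionary index and the dictionary values as two separate JSON strings and produce a DictionaryScalar. A minimal usage sketch, mirroring the TestDictScalarFromJSON.Basics test above and using the value-returning wrapper from arrow/testing/gtest_util.h:

    auto type = dictionary(int32(), utf8());
    std::shared_ptr<Scalar> scalar =
        DictScalarFromJSON(type, "2", R"(["whiskey", "tango", "foxtrot"])");
    // Equivalent to:
    //   DictionaryScalar::Make(ScalarFromJSON(int32(), "2"),
    //                          ArrayFromJSON(utf8(), R"(["whiskey", "tango", "foxtrot"])"))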
+std::shared_ptr DictScalarFromJSON(const std::shared_ptr& type, + util::string_view index_json, + util::string_view dictionary_json) { + std::shared_ptr out; + ABORT_NOT_OK( + ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out)); + return out; +} + std::shared_ptr TableFromJSON(const std::shared_ptr& schema, const std::vector& json) { std::vector> batches; diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index f0021e05603..65ab33c5d1f 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -338,6 +338,11 @@ ARROW_TESTING_EXPORT std::shared_ptr ScalarFromJSON(const std::shared_ptr&, util::string_view json); +ARROW_TESTING_EXPORT +std::shared_ptr DictScalarFromJSON(const std::shared_ptr&, + util::string_view index_json, + util::string_view dictionary_json); + ARROW_TESTING_EXPORT std::shared_ptr
TableFromJSON(const std::shared_ptr&, const std::vector& json); From 2ec113219f2d9733d0534de70eb0dd05f6124b75 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 26 Aug 2021 10:59:54 -0400 Subject: [PATCH 02/21] ARROW-13573: [C++] Check that dictionary array has dictionary --- cpp/src/arrow/array/validate.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index c66c4f53b9d..1715863014c 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -568,6 +568,9 @@ struct ValidateArrayFullImpl { } Status Visit(const DictionaryType& type) { + if (!data.dictionary) { + return Status::Invalid("Dictionary array has no dictionary"); + } const Status indices_status = CheckBounds(*type.index_type(), 0, data.dictionary->length - 1); if (!indices_status.ok()) { From e3b7f93f892bb01b735f3066bf814779b7385350 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 26 Aug 2021 11:00:32 -0400 Subject: [PATCH 03/21] ARROW-13573: [C++] Handle simple dictionary cases --- .../arrow/compute/kernels/scalar_if_else.cc | 85 ++++++++++++++++++- .../compute/kernels/scalar_if_else_test.cc | 81 ++++++++++++++++++ 2 files changed, 164 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 4de04da7a81..fc69c0e2fd0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1092,6 +1092,63 @@ struct CopyFixedWidth> { } }; +template +struct CopyFixedWidth> { + // TODO: how are we going to deal with passing down a mapping? + static void CopyScalar(const Scalar& scalar, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const auto& index = *checked_cast(scalar).value.index; + switch (index.type->id()) { + case arrow::Type::INT8: + case arrow::Type::UINT8: + CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); + break; + case arrow::Type::INT16: + case arrow::Type::UINT16: + CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); + break; + case arrow::Type::INT32: + case arrow::Type::UINT32: + CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); + break; + case arrow::Type::INT64: + case arrow::Type::UINT64: + CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); + break; + default: + ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type; + } + } + static void CopyArray(const DataType& type, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const auto& index_type = *checked_cast(type).index_type(); + switch (index_type.id()) { + case arrow::Type::INT8: + case arrow::Type::UINT8: + CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, + raw_out_values, out_offset); + break; + case arrow::Type::INT16: + case arrow::Type::UINT16: + CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, + raw_out_values, out_offset); + break; + case arrow::Type::INT32: + case arrow::Type::UINT32: + CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, + raw_out_values, out_offset); + break; + case arrow::Type::INT64: + case arrow::Type::UINT64: + CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, + raw_out_values, out_offset); + break; + default: + ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type; + } + } +}; template struct 
CopyFixedWidth> { static void CopyScalar(const Scalar& values, const int64_t length, @@ -1222,7 +1279,6 @@ struct CaseWhenFunction : ScalarFunction { // The first function is a struct of booleans, where the number of fields in the // struct is either equal to the number of other arguments or is one less. RETURN_NOT_OK(CheckArity(*values)); - EnsureDictionaryDecoded(values); auto first_type = (*values)[0].type; if (first_type->id() != Type::STRUCT) { return Status::TypeError("case_when: first argument must be STRUCT, not ", @@ -1243,6 +1299,9 @@ struct CaseWhenFunction : ScalarFunction { } } + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + EnsureDictionaryDecoded(values); if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { for (auto it = values->begin() + 1; it != values->end(); it++) { it->type = type; @@ -1275,10 +1334,20 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out } } if (out->is_scalar()) { + // TODO: makenullscalar here should give the expected dictionary *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type()); return Status::OK(); } ArrayData* output = out->mutable_array(); + if (is_dictionary_type::value) { + const Datum& dict_from = result.is_value() ? result : batch[1]; + if (dict_from.is_scalar()) { + output->dictionary = checked_cast(*dict_from.scalar()) + .value.dictionary->data(); + } else { + output->dictionary = dict_from.array()->dictionary; + } + } if (!result.is_value()) { // All conditions false, no 'else' argument result = MakeNullScalar(out->type()); @@ -1313,6 +1382,17 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false); } + if (is_dictionary_type::value) { + // We always use the dictionary of the first argument + const Datum& dict_from = batch[1]; + if (dict_from.is_scalar()) { + output->dictionary = checked_cast(*dict_from.scalar()) + .value.dictionary->data(); + } else { + output->dictionary = dict_from.array()->dictionary; + } + } + // Allocate a temporary bitmap to determine which elements still need setting. 
ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length)); uint8_t* mask = mask_buffer->mutable_data(); @@ -2446,7 +2526,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { } { auto func = std::make_shared( - "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc); + "case_when", Arity::VarArgs(/*min_args=*/2), &case_when_doc); AddPrimitiveCaseWhenKernels(func, NumericTypes()); AddPrimitiveCaseWhenKernels(func, TemporalTypes()); AddPrimitiveCaseWhenKernels(func, IntervalTypes()); @@ -2464,6 +2544,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index b3b0f26cead..ca6b688eb3d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -624,6 +624,87 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) { ArrayFromJSON(type, R"([null, null, null, [6, null]])")); } +template +class TestCaseWhenInteger : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenInteger, IntegralArrowTypes); + +TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingSimple) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); + + // Easy case: all arguments have the same dictionary + // CheckScalar("case_when", {MakeStruct({Datum(false)}), DictScalarFromJSON(type, "1", + // dict)}, + // DictScalarFromJSON(type, "[0, null, null, null]", dict)); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + DictArrayFromJSON(type, "[0, null, null, null]", dict)); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + DictArrayFromJSON(type, "[0, null, null, 1]", dict)); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + DictArrayFromJSON(type, "[null, null, null, 1]", dict)); +} + +TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingMixed) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); + auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); + auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); + auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); + auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])"); + + // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}, + ArrayFromJSON(utf8(), R"(["a", null, null, null])")); + CheckScalar("case_when", 
{MakeStruct({cond1, cond2}), values1_decoded, values2_dict}, + ArrayFromJSON(utf8(), R"(["a", null, null, null])")); + CheckScalar("case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded}, + ArrayFromJSON(utf8(), R"(["a", null, null, null])")); + CheckScalar("case_when", + {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded}, + ArrayFromJSON(utf8(), R"([null, null, null, null])")); +} + +TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingDifferentDictionaries) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict1 = R"(["a", null, "bc", "def"])"; + auto dict2 = R"(["bc", "foo", null, "a"])"; + auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); + auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2); + + // For scalar conditions, we borrow the dictionary of the chosen output (or the first + // input when outputting null) + CheckScalar("case_when", {MakeStruct({Datum(true), Datum(false)}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({Datum(false), Datum(true)}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values1, values2}, + values1_null); + CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values2, values1}, + values2_null); + + // For array conditions, we always borrow the dictionary of the first input + + // When mixing dictionaries, we try to map other dictionaries onto the first one + + // If we can't map values from a dictionary, then raise an error + + // ...or optionally, raise an error +} + TEST(TestCaseWhen, Null) { auto cond_true = ScalarFromJSON(boolean(), "true"); auto cond_false = ScalarFromJSON(boolean(), "false"); From 7a57c91420222185890f1f74ec46f86388da9cd9 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 27 Aug 2021 11:41:38 -0400 Subject: [PATCH 04/21] ARROW-13573: [C++] Transpose dictionaries in case_when --- .../arrow/compute/kernels/scalar_if_else.cc | 377 +++++++++++++++++- .../compute/kernels/scalar_if_else_test.cc | 148 ++++++- cpp/src/arrow/compute/kernels/test_util.cc | 14 +- cpp/src/arrow/compute/kernels/test_util.h | 7 + 4 files changed, 498 insertions(+), 48 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index fc69c0e2fd0..6a2ca01c439 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. 
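[Editor's note - illustrative, not part of the patch] For array conditions this patch (04/21) keeps the first value argument's dictionary as the output dictionary and adds a remapping step: a memo table is built from that dictionary and every other argument's indices are transposed onto it, raising Invalid when a selected value has no counterpart (scalar-condition cases still pick one input's dictionary without remapping). A worked example with the dictionaries used in the tests:

    // reference (output) dictionary:  dict1 = ["a", null, "bc", "def"]
    // other argument's dictionary:    dict2 = ["bc", "foo", null, "a"]
    // transpose map dict2 -> dict1:   [2, null, 1, 0]
    //   ("bc" -> 2, "foo" -> no match, null -> 1, "a" -> 0)
    // a dict2 index of 1 ("foo") therefore cannot be selected and raises Invalid.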
+#include + #include "arrow/array/builder_nested.h" #include "arrow/array/builder_primitive.h" #include "arrow/array/builder_time.h" #include "arrow/array/builder_union.h" +#include "arrow/array/dict_internal.h" #include "arrow/compute/api.h" #include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/util_internal.h" @@ -27,6 +30,7 @@ #include "arrow/util/bitmap.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_reader.h" +#include "arrow/util/int_util.h" namespace arrow { @@ -1058,6 +1062,109 @@ void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_funct DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } +// Given a reference dictionary, computes indices to map dictionary values from a +// comparison dictionary to the reference. +class DictionaryRemapper { + public: + virtual ~DictionaryRemapper() = default; + virtual Status Init(const Array& dictionary) = 0; + virtual Result> Remap(const Array& dictionary) = 0; + Result> Remap(const Datum& has_dictionary) { + DCHECK_EQ(has_dictionary.type()->id(), Type::DICTIONARY); + if (has_dictionary.is_scalar()) { + return Remap(*checked_cast(*has_dictionary.scalar()) + .value.dictionary); + } else { + return Remap(*MakeArray(has_dictionary.array()->dictionary)); + } + } +}; + +template +class DictionaryRemapperImpl : public DictionaryRemapper { + public: + using ArrayType = typename TypeTraits::ArrayType; + using MemoTableType = typename arrow::internal::HashTraits::MemoTableType; + + explicit DictionaryRemapperImpl(MemoryPool* pool) : pool_(pool), memo_table_(pool) {} + + Status Init(const Array& dictionary) override { + const ArrayType& values = checked_cast(dictionary); + if (values.length() > std::numeric_limits::max()) { + return Status::CapacityError("Cannot remap dictionary with more than ", + std::numeric_limits::max(), + " elements, have: ", values.length()); + } + for (int32_t i = 0; i < values.length(); ++i) { + if (values.IsNull(i)) { + memo_table_.GetOrInsertNull(); + continue; + } + int32_t unused_memo_index = 0; + RETURN_NOT_OK(memo_table_.GetOrInsert(values.GetView(i), &unused_memo_index)); + } + return Status::OK(); + } + + Result> Remap(const Array& dictionary) override { + const ArrayType& values = checked_cast(dictionary); + std::shared_ptr valid_buffer; + ARROW_ASSIGN_OR_RAISE(auto indices_buffer, + AllocateBuffer(dictionary.length() * sizeof(int32_t), pool_)); + int32_t* indices = reinterpret_cast(indices_buffer->mutable_data()); + int64_t null_count = 0; + for (int64_t i = 0; i < values.length(); ++i) { + int32_t index = -1; + index = + values.IsNull(i) ? 
memo_table_.GetNull() : memo_table_.Get(values.GetView(i)); + indices[i] = std::max(0, index); + if (index == arrow::internal::kKeyNotFound && !valid_buffer) { + ARROW_ASSIGN_OR_RAISE( + valid_buffer, + AllocateBuffer(BitUtil::BytesForBits(dictionary.length()), pool_)); + std::memset(valid_buffer->mutable_data(), 0xFF, valid_buffer->size()); + } + if (index == arrow::internal::kKeyNotFound) { + BitUtil::ClearBit(valid_buffer->mutable_data(), i); + null_count++; + } + } + return arrow::internal::make_unique(dictionary.length(), + std::move(indices_buffer), + std::move(valid_buffer), null_count); + } + + private: + MemoryPool* pool_; + MemoTableType memo_table_; +}; + +struct MakeRemapper { + template + enable_if_no_memoize Visit(const T& value_type) { + return Status::NotImplemented("Unification of ", value_type, + " dictionaries is not implemented"); + } + + template + enable_if_memoize Visit(const T&) { + result_.reset(new DictionaryRemapperImpl(pool_)); + return Status::OK(); + } + + static Result> Make(MemoryPool* pool, + const Array& dictionary) { + const auto& value_type = *dictionary.type(); + MakeRemapper impl{pool, /*result_=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(value_type, &impl)); + RETURN_NOT_OK(impl.result_->Init(dictionary)); + return std::move(impl.result_); + } + + MemoryPool* pool_; + std::unique_ptr result_; +}; + // Helper to copy or broadcast fixed-width values between buffers. template struct CopyFixedWidth {}; @@ -1084,17 +1191,33 @@ struct CopyFixedWidth> { const CType value = UnboxScalar::Unbox(scalar); std::fill(out_values + out_offset, out_values + out_offset + length, value); } + static void CopyScalar(const Scalar& scalar, const Int32Array& transpose_map, + const int64_t length, uint8_t* raw_out_values, + const int64_t out_offset) { + CType* out_values = reinterpret_cast(raw_out_values); + const CType value = UnboxScalar::Unbox(scalar); + const CType transposed = static_cast(transpose_map.raw_values()[value]); + std::fill(out_values + out_offset, out_values + out_offset + length, transposed); + } static void CopyArray(const DataType&, const uint8_t* in_values, const int64_t in_offset, const int64_t length, uint8_t* raw_out_values, const int64_t out_offset) { std::memcpy(raw_out_values + out_offset * sizeof(CType), in_values + in_offset * sizeof(CType), length * sizeof(CType)); } + static void CopyArray(const DataType&, const Int32Array& transpose_map, + const uint8_t* in_values, const int64_t in_offset, + const int64_t length, uint8_t* raw_out_values, + const int64_t out_offset) { + arrow::internal::TransposeInts( + reinterpret_cast(in_values) + in_offset, + reinterpret_cast(raw_out_values) + out_offset, length, + transpose_map.raw_values()); + } }; template struct CopyFixedWidth> { - // TODO: how are we going to deal with passing down a mapping? 
static void CopyScalar(const Scalar& scalar, const int64_t length, uint8_t* raw_out_values, const int64_t out_offset) { const auto& index = *checked_cast(scalar).value.index; @@ -1119,6 +1242,35 @@ struct CopyFixedWidth> { ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type; } } + static void CopyScalar(const Scalar& scalar, const Int32Array& transpose_map, + const int64_t length, uint8_t* raw_out_values, + const int64_t out_offset) { + const auto& index = *checked_cast(scalar).value.index; + switch (index.type->id()) { + case arrow::Type::INT8: + case arrow::Type::UINT8: + CopyFixedWidth::CopyScalar(index, transpose_map, length, + raw_out_values, out_offset); + break; + case arrow::Type::INT16: + case arrow::Type::UINT16: + CopyFixedWidth::CopyScalar(index, transpose_map, length, + raw_out_values, out_offset); + break; + case arrow::Type::INT32: + case arrow::Type::UINT32: + CopyFixedWidth::CopyScalar(index, transpose_map, length, + raw_out_values, out_offset); + break; + case arrow::Type::INT64: + case arrow::Type::UINT64: + CopyFixedWidth::CopyScalar(index, transpose_map, length, + raw_out_values, out_offset); + break; + default: + ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type; + } + } static void CopyArray(const DataType& type, const uint8_t* in_values, const int64_t in_offset, const int64_t length, uint8_t* raw_out_values, const int64_t out_offset) { @@ -1148,6 +1300,40 @@ struct CopyFixedWidth> { ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type; } } + static void CopyArray(const DataType& type, const Int32Array& transpose_map, + const uint8_t* in_values, const int64_t in_offset, + const int64_t length, uint8_t* raw_out_values, + const int64_t out_offset) { + const auto& index_type = *checked_cast(type).index_type(); + switch (index_type.id()) { + case arrow::Type::INT8: + case arrow::Type::UINT8: + CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, + in_offset, length, raw_out_values, + out_offset); + break; + case arrow::Type::INT16: + case arrow::Type::UINT16: + CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, + in_offset, length, raw_out_values, + out_offset); + break; + case arrow::Type::INT32: + case arrow::Type::UINT32: + CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, + in_offset, length, raw_out_values, + out_offset); + break; + case arrow::Type::INT64: + case arrow::Type::UINT64: + CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, + in_offset, length, raw_out_values, + out_offset); + break; + default: + ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type; + } + } }; template struct CopyFixedWidth> { @@ -1203,8 +1389,11 @@ struct CopyFixedWidth> { // Copy fixed-width values from a scalar/array datum into an output values buffer template -void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length, - uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) { +enable_if_t::value> CopyValues( + const Datum& in_values, const int64_t in_offset, const int64_t length, + uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset, + const Int32Array* transpose_map = nullptr) { + DCHECK(!transpose_map); if (in_values.is_scalar()) { const auto& scalar = *in_values.scalar(); if (out_valid) { @@ -1234,6 +1423,123 @@ void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t l } } +// Copy values, optionally transposing dictionary indices. 
+template +enable_if_dictionary CopyValues(const Datum& in_values, const int64_t in_offset, + const int64_t length, uint8_t* out_valid, + uint8_t* out_values, const int64_t out_offset, + const Int32Array* transpose_map) { + if (in_values.is_scalar()) { + const auto& scalar = *in_values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid); + } + if (transpose_map) { + CopyFixedWidth::CopyScalar(scalar, *transpose_map, length, out_values, + out_offset); + } else { + CopyFixedWidth::CopyScalar(scalar, length, out_values, out_offset); + } + } else { + const ArrayData& array = *in_values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + if (length == 1) { + // CopyBitmap is slow for short runs + BitUtil::SetBitTo( + out_valid, out_offset, + BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset)); + } else { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset, + length, out_valid, out_offset); + } + } else { + BitUtil::SetBitsTo(out_valid, out_offset, length, true); + } + } + if (transpose_map) { + CopyFixedWidth::CopyArray(*array.type, *transpose_map, + array.buffers[1]->data(), array.offset + in_offset, + length, out_values, out_offset); + } else { + CopyFixedWidth::CopyArray(*array.type, array.buffers[1]->data(), + array.offset + in_offset, length, out_values, + out_offset); + } + } +} + +/// Check that we can actually remap dictionary indices from one dictionary to +/// another without losing data. +struct CheckValidTranspositionArrayImpl { + template + enable_if_number Visit(const T&) { + using c_type = typename T::c_type; + const c_type* values = arr.GetValues(1); + // TODO: eventually offer the option to zero out the bitmap instead + return arrow::internal::VisitSetBitRuns( + arr.buffers[0], arr.offset + offset, length, + [&](int64_t position, int64_t length) { + for (int64_t i = 0; i < length; i++) { + const uint64_t idx = static_cast(values[offset + position + i]); + if (!BitUtil::GetBit(transpose_valid, idx)) { + return Status::Invalid("Cannot map dictionary index ", idx, " at position ", + offset + position + i, " to the common dictionary"); + } + } + return Status::OK(); + }); + } + + Status Visit(const DataType& ty) { + return Status::TypeError("Dictionary cannot have index type", ty); + } + + const ArrayData& arr; + const int64_t offset; + const int64_t length; + const uint8_t* transpose_valid; +}; + +struct CheckValidTranspositionScalarImpl { + template + enable_if_number Visit(const T&) { + const uint64_t idx = static_cast(UnboxScalar::Unbox( + *checked_cast(scalar).value.index)); + // TODO: eventually offer the option to zero out the bitmap instead + if (!BitUtil::GetBit(transpose_valid, idx)) { + return Status::Invalid("Cannot map dictionary index ", idx, + " to the common dictionary"); + } + return Status::OK(); + } + + Status Visit(const DataType& ty) { + return Status::TypeError("Dictionary cannot have index type", ty); + } + + const Scalar& scalar; + const uint8_t* transpose_valid; +}; + +Status CheckValidTransposition(const Datum& values, const int64_t offset, + const int64_t length, const Int32Array* transpose_map) { + // Note we assume the transpose map never has an offset + if (!transpose_map || transpose_map->null_count() == 0) return Status::OK(); + DCHECK_EQ(values.type()->id(), Type::DICTIONARY); + if (values.is_scalar()) { + const Scalar& scalar = *values.scalar(); + CheckValidTranspositionScalarImpl impl{scalar, transpose_map->null_bitmap_data()}; + return 
VisitTypeInline( + *checked_cast(*scalar.type).index_type(), &impl); + } + const ArrayData& arr = *values.array(); + CheckValidTranspositionArrayImpl impl{arr, offset, length, + transpose_map->null_bitmap_data()}; + return VisitTypeInline(*checked_cast(*arr.type).index_type(), + &impl); +} + // Specialized helper to copy a single value from a source array. Allows avoiding // repeatedly calling MayHaveNulls and Buffer::data() which have internal checks that // add up when called in a loop. @@ -1334,7 +1640,6 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out } } if (out->is_scalar()) { - // TODO: makenullscalar here should give the expected dictionary *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type()); return Status::OK(); } @@ -1352,9 +1657,9 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out // All conditions false, no 'else' argument result = MakeNullScalar(out->type()); } - CopyValues(result, /*in_offset=*/0, batch.length, - output->GetMutableValues(0, 0), - output->GetMutableValues(1, 0), output->offset); + CopyValues( + result, /*in_offset=*/0, batch.length, output->GetMutableValues(0, 0), + output->GetMutableValues(1, 0), output->offset, /*transpose_map=*/nullptr); return Status::OK(); } @@ -1373,15 +1678,9 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) static_cast(conds_array.type->num_fields()) < num_value_args; uint8_t* out_valid = output->buffers[0]->mutable_data(); uint8_t* out_values = output->buffers[1]->mutable_data(); - if (have_else_arg) { - // Copy 'else' value into output - CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, - out_values, out_offset); - } else { - // There's no 'else' argument, so we should have an all-null validity bitmap - BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false); - } + std::unique_ptr remapper; + std::unique_ptr transpose_map; if (is_dictionary_type::value) { // We always use the dictionary of the first argument const Datum& dict_from = batch[1]; @@ -1391,6 +1690,20 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) } else { output->dictionary = dict_from.array()->dictionary; } + ARROW_ASSIGN_OR_RAISE( + remapper, MakeRemapper::Make(ctx->memory_pool(), *MakeArray(output->dictionary))); + } + + if (have_else_arg) { + // Copy 'else' value into output + if (is_dictionary_type::value) { + ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(batch.values.back())); + } + CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, + out_values, out_offset, transpose_map.get()); + } else { + // There's no 'else' argument, so we should have an all-null validity bitmap + BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false); } // Allocate a temporary bitmap to determine which elements still need setting. 
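[Editor's note - illustrative, not part of the patch] At the user level the remapping is transparent whenever every selected value also exists in the first value argument's dictionary. A sketch of the mixed-dictionary case exercised by the tests further down (the index type int8() is chosen arbitrarily; the tests are templated over the integer index types):

    auto type = dictionary(int8(), utf8());
    auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", R"(["a", null, "bc", "def"])");
    auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", R"(["bc", "foo", null, "a"])");
    ASSERT_OK_AND_ASSIGN(
        Datum out,
        CallFunction("case_when",
                     {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
                      values1, values2}));
    // out is encoded against values1's dictionary: indices [0, null, null, 2]
    // over ["a", null, "bc", "def"]; selecting "foo" from values2 instead would
    // raise Invalid ("Cannot map dictionary index 1 ... to the common dictionary").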
@@ -1406,6 +1719,10 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) const Datum& values_datum = batch[i + 1]; int64_t offset = 0; + if (is_dictionary_type::value) { + ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(values_datum)); + } + if (cond_array.GetNullCount() == 0) { // If no valid buffer, visit mask & cond bitmap simultaneously BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset, @@ -1413,15 +1730,19 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) while (offset < batch.length) { const auto block = counter.NextAndWord(); if (block.AllSet()) { + RETURN_NOT_OK(CheckValidTransposition(values_datum, offset, block.length, + transpose_map.get())); CopyValues(values_datum, offset, block.length, out_valid, out_values, - out_offset + offset); + out_offset + offset, transpose_map.get()); BitUtil::SetBitsTo(mask, offset, block.length, false); } else if (block.popcount) { for (int64_t j = 0; j < block.length; ++j) { if (BitUtil::GetBit(mask, offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j)) { + RETURN_NOT_OK(CheckValidTransposition(values_datum, offset + j, + /*length=*/1, transpose_map.get())); CopyValues(values_datum, offset + j, /*length=*/1, out_valid, - out_values, out_offset + offset + j); + out_values, out_offset + offset + j, transpose_map.get()); BitUtil::SetBitTo(mask, offset + j, false); } } @@ -1434,25 +1755,31 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, {cond_values, cond_offset, batch.length}, {cond_valid, cond_offset, batch.length}}; + Status valid_transposition = Status::OK(); Bitmap::VisitWords(bitmaps, [&](std::array words) { const uint64_t word = words[0] & words[1] & words[2]; const int64_t block_length = std::min(64, batch.length - offset); if (word == std::numeric_limits::max()) { + valid_transposition &= CheckValidTransposition( + values_datum, offset, block_length, transpose_map.get()); CopyValues(values_datum, offset, block_length, out_valid, out_values, - out_offset + offset); + out_offset + offset, transpose_map.get()); BitUtil::SetBitsTo(mask, offset, block_length, false); } else if (word) { for (int64_t j = 0; j < block_length; ++j) { if (BitUtil::GetBit(mask, offset + j) && BitUtil::GetBit(cond_valid, cond_offset + offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j)) { + valid_transposition &= CheckValidTransposition( + values_datum, offset + j, /*length=*/1, transpose_map.get()); CopyValues(values_datum, offset + j, /*length=*/1, out_valid, - out_values, out_offset + offset + j); + out_values, out_offset + offset + j, transpose_map.get()); BitUtil::SetBitTo(mask, offset + j, false); } } } }); + RETURN_NOT_OK(valid_transposition); } } if (!have_else_arg) { @@ -1483,6 +1810,18 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) } offset += block.length; } + } else if (is_dictionary_type::value) { + // Check that any 'else' slots that were not overwritten are valid transpositions. 
+ arrow::internal::SetBitRunReader reader(mask, /*offset=*/0, batch.length); + if (is_dictionary_type::value) { + ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(batch.values.back())); + } + while (true) { + const auto run = reader.NextRun(); + if (run.length == 0) break; + RETURN_NOT_OK(CheckValidTransposition(batch.values.back(), run.position, run.length, + transpose_map.get())); + } } return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index ca6b688eb3d..c8c22f5ae76 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -625,32 +625,38 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) { } template -class TestCaseWhenInteger : public ::testing::Test {}; +class TestCaseWhenDict : public ::testing::Test {}; -TYPED_TEST_SUITE(TestCaseWhenInteger, IntegralArrowTypes); +struct JsonDict { + std::shared_ptr type; + std::string value; +}; -TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingSimple) { - auto type = dictionary(default_type_instance(), utf8()); +TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes); + +TYPED_TEST(TestCaseWhenDict, Simple) { auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto dict = R"(["a", null, "bc", "def"])"; - auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); - auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); - auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); - - // Easy case: all arguments have the same dictionary - // CheckScalar("case_when", {MakeStruct({Datum(false)}), DictScalarFromJSON(type, "1", - // dict)}, - // DictScalarFromJSON(type, "[0, null, null, null]", dict)); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, - DictArrayFromJSON(type, "[0, null, null, null]", dict)); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, - DictArrayFromJSON(type, "[0, null, null, 1]", dict)); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, - DictArrayFromJSON(type, "[null, null, null, 1]", dict)); + for (const auto& dict : + {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, + JsonDict{int64(), "[1, null, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { + auto type = dictionary(default_type_instance(), dict.type); + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value); + + // Easy case: all arguments have the same dictionary + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + DictArrayFromJSON(type, "[0, null, null, null]", dict.value)); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + DictArrayFromJSON(type, "[0, null, null, 1]", dict.value)); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + DictArrayFromJSON(type, "[null, null, null, 1]", dict.value)); + } } -TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingMixed) { +TYPED_TEST(TestCaseWhenDict, Mixed) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); @@ -674,16 
+680,18 @@ TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingMixed) { ArrayFromJSON(utf8(), R"([null, null, null, null])")); } -TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingDifferentDictionaries) { +TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); auto dict1 = R"(["a", null, "bc", "def"])"; auto dict2 = R"(["bc", "foo", null, "a"])"; + auto dict3 = R"(["def", "a", "a", "bc"])"; auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1); auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2); + auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3); // For scalar conditions, we borrow the dictionary of the chosen output (or the first // input when outputting null) @@ -697,12 +705,108 @@ TYPED_TEST(TestCaseWhenInteger, DictionaryEncodingDifferentDictionaries) { values2_null); // For array conditions, we always borrow the dictionary of the first input + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + DictArrayFromJSON(type, "[0, null, null, null]", dict1)); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + DictArrayFromJSON(type, "[0, null, null, 1]", dict1)); // When mixing dictionaries, we try to map other dictionaries onto the first one + // Don't check the scalar cases since we don't remap dictionaries in that case + CheckScalarNonRecursive( + "case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), values1, + values2}, + DictArrayFromJSON(type, "[0, null, null, 2]", dict1)); + CheckScalarNonRecursive( + "case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[true, false, true, false]")}), + values1, values2}, + DictArrayFromJSON(type, "[0, null, null, null]", dict1)); + CheckScalarNonRecursive( + "case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}, + DictArrayFromJSON(type, "[3, 0, 0, 2]", dict1)); + CheckScalarNonRecursive( + "case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}, + DictArrayFromJSON(type, "[3, 0, 0, 1]", dict1)); + CheckScalarNonRecursive( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), + DictScalarFromJSON(type, "0", dict1), + DictScalarFromJSON(type, "0", dict2), + }, + DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1)); + CheckScalarNonRecursive( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[false, false, true, true]")}), + DictScalarFromJSON(type, "0", dict1), + DictScalarFromJSON(type, "0", dict2), + }, + DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1)); // If we can't map values from a dictionary, then raise an error + // Unmappable value is in the else clause + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "Cannot map dictionary index 1 at position 1 to the common dictionary"), + CallFunction( + "case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]")}), + values1, 
values2})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Cannot map dictionary index 1 to the common dictionary"), + CallFunction( + "case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]")}), + values1, DictScalarFromJSON(type, "1", dict2)})); + // Unmappable value is in a branch (test multiple times to ensure coverage of branches + // in impl) + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "Cannot map dictionary index 1 at position 1 to the common dictionary"), + CallFunction("case_when", + {MakeStruct({Datum(false), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values2})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "Cannot map dictionary index 1 at position 1 to the common dictionary"), + CallFunction("case_when", + {MakeStruct({Datum(false), + ArrayFromJSON(boolean(), "[false, true, false, false]")}), + values1, values2})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "Cannot map dictionary index 1 at position 1 to the common dictionary"), + CallFunction("case_when", + {MakeStruct({Datum(false), + ArrayFromJSON(boolean(), "[null, true, null, null]")}), + values1, values2})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Cannot map dictionary index 1 to the common dictionary"), + CallFunction("case_when", + {MakeStruct({Datum(false), + ArrayFromJSON(boolean(), "[true, true, true, null]")}), + values1, DictScalarFromJSON(type, "1", dict2)})); + + // ...or optionally, emit null - // ...or optionally, raise an error + // TODO: this is not implemented yet } TEST(TestCaseWhen, Null) { diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 4a9215101b1..9a779e49163 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -46,13 +46,6 @@ DatumVector GetDatums(const std::vector& inputs) { return datums; } -void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, - const Datum& expected, const FunctionOptions* options) { - ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); - ValidateOutput(out); - AssertDatumsEqual(expected, out, /*verbose=*/true); -} - template DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) { DatumVector sliced; @@ -80,6 +73,13 @@ ScalarVector GetScalars(const DatumVector& inputs, int64_t index) { } // namespace +void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, + const Datum& expected, const FunctionOptions* options) { + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); + ValidateOutput(out); + AssertDatumsEqual(expected, out, /*verbose=*/true); +} + void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options)); diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 79745b05552..0931f3c77bc 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -67,6 +67,8 @@ inline std::string CompareOperatorToFunctionName(CompareOperator op) { return function_names[op]; } +// Call the function with the given arguments, as well as slices of +// the arguments and scalars extracted from the arguments. 
void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); @@ -74,6 +76,11 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected, const FunctionOptions* options = nullptr); +// Just call the function with the given arguments. +void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, + const Datum& expected, + const FunctionOptions* options = nullptr); + void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, From 17230ee41b6c8080dff25f4a1e125a4120fd4ed0 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 30 Aug 2021 15:15:25 -0400 Subject: [PATCH 05/21] ARROW-13573: [C++] Handle nested dictionaries --- cpp/src/arrow/array/array_test.cc | 12 +- cpp/src/arrow/array/builder_base.cc | 10 +- cpp/src/arrow/array/builder_base.h | 13 +- cpp/src/arrow/array/builder_dict.cc | 38 ++- cpp/src/arrow/array/builder_dict.h | 167 +++++++++++ cpp/src/arrow/builder.cc | 269 ++++++++++-------- .../arrow/compute/kernels/scalar_if_else.cc | 2 +- .../compute/kernels/scalar_if_else_test.cc | 59 ++++ 8 files changed, 410 insertions(+), 160 deletions(-) diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index d9617c4e603..2e3d4057094 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -456,7 +456,7 @@ TEST_F(TestArray, TestValidateNullCount) { void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) { std::unique_ptr builder; auto null_scalar = MakeNullScalar(scalar->type); - ASSERT_OK(MakeBuilder(pool, scalar->type, &builder)); + ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder)); ASSERT_OK(builder->AppendScalar(*scalar)); ASSERT_OK(builder->AppendScalar(*scalar)); ASSERT_OK(builder->AppendScalar(*null_scalar)); @@ -471,15 +471,18 @@ void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) ASSERT_EQ(out->length(), 9); const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id()); + // For a dictionary builder, the output dictionary won't necessarily be the same + const bool can_check_values = !is_dictionary(out->type()->id()); if (can_check_nulls) { ASSERT_EQ(out->null_count(), 4); } + for (const auto index : {0, 1, 3, 5, 6}) { ASSERT_FALSE(out->IsNull(index)); ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index)); ASSERT_OK(scalar_i->ValidateFull()); - AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); + if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); } for (const auto index : {2, 4, 7, 8}) { ASSERT_EQ(out->IsNull(index), can_check_nulls); @@ -575,8 +578,6 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { } for (auto scalar : scalars) { - // TODO(ARROW-13197): appending dictionary scalars not implemented - if (is_dictionary(scalar->type->id())) continue; AssertAppendScalar(pool_, scalar); } } @@ -634,9 +635,6 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) { TEST_F(TestArray, TestAppendArraySlice) { auto scalars = GetScalars(); for (const auto& scalar : scalars) { - // TODO(ARROW-13573): appending dictionary arrays not implemented - if (is_dictionary(scalar->type->id())) continue; - ARROW_SCOPED_TRACE(*scalar->type); ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16)); ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16)); diff 
--git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 2f4e63b546d..117b9d37632 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -22,6 +22,7 @@ #include #include "arrow/array/array_base.h" +#include "arrow/array/builder_dict.h" #include "arrow/array/data.h" #include "arrow/array/util.h" #include "arrow/buffer.h" @@ -268,15 +269,6 @@ struct AppendScalarImpl { } // namespace -Status ArrayBuilder::AppendScalar(const Scalar& scalar) { - if (!scalar.type->Equals(type())) { - return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), - " to builder for type ", type()->ToString()); - } - std::shared_ptr shared{const_cast(&scalar), [](Scalar*) {}}; - return AppendScalarImpl{&shared, &shared + 1, /*n_repeats=*/1, this}.Convert(); -} - Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) { if (!scalar.type->Equals(type())) { return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index 67203e79071..87e39c3fe9f 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -119,9 +119,9 @@ class ARROW_EXPORT ArrayBuilder { virtual Status AppendEmptyValues(int64_t length) = 0; /// \brief Append a value from a scalar - Status AppendScalar(const Scalar& scalar); - Status AppendScalar(const Scalar& scalar, int64_t n_repeats); - Status AppendScalars(const ScalarVector& scalars); + Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); } + virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats); + virtual Status AppendScalars(const ScalarVector& scalars); /// \brief Append a range of values from an array. /// @@ -282,6 +282,13 @@ ARROW_EXPORT Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out); +/// \brief Construct an empty ArrayBuilder corresponding to the data +/// type, where any top-level or nested dictionary builders return the +/// exact index type specified by the type. 
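[Editor's note - illustrative, not part of the patch] A sketch of how the new entry point differs from MakeBuilder for dictionary types, mirroring the updated AssertAppendScalar helper above (pool and types chosen arbitrarily here):

    std::unique_ptr<ArrayBuilder> builder;
    ASSERT_OK(MakeBuilderExactIndex(default_memory_pool(),
                                    dictionary(int8(), utf8()), &builder));
    // `builder` appends int8 indices, matching the requested index type exactly;
    // plain MakeBuilder would instead return an adaptive-index DictionaryBuilder
    // whose index width can grow as distinct values are appended.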
+ARROW_EXPORT +Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out); + /// \brief Construct an empty DictionaryBuilder initialized optionally /// with a pre-existing dictionary /// \param[in] pool the MemoryPool to use for allocations diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index b13f6a2db34..b554c1d7099 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -159,23 +159,31 @@ DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool, DictionaryMemoTable::~DictionaryMemoTable() = default; -#define GET_OR_INSERT(C_TYPE) \ - Status DictionaryMemoTable::GetOrInsert( \ - const typename CTypeTraits::ArrowType*, C_TYPE value, int32_t* out) { \ - return impl_->GetOrInsert::ArrowType>(value, out); \ +#define GET_OR_INSERT(ARROW_TYPE) \ + Status DictionaryMemoTable::GetOrInsert( \ + const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \ + return impl_->GetOrInsert(value, out); \ } -GET_OR_INSERT(bool) -GET_OR_INSERT(int8_t) -GET_OR_INSERT(int16_t) -GET_OR_INSERT(int32_t) -GET_OR_INSERT(int64_t) -GET_OR_INSERT(uint8_t) -GET_OR_INSERT(uint16_t) -GET_OR_INSERT(uint32_t) -GET_OR_INSERT(uint64_t) -GET_OR_INSERT(float) -GET_OR_INSERT(double) +GET_OR_INSERT(BooleanType) +GET_OR_INSERT(Int8Type) +GET_OR_INSERT(Int16Type) +GET_OR_INSERT(Int32Type) +GET_OR_INSERT(Int64Type) +GET_OR_INSERT(UInt8Type) +GET_OR_INSERT(UInt16Type) +GET_OR_INSERT(UInt32Type) +GET_OR_INSERT(UInt64Type) +GET_OR_INSERT(FloatType) +GET_OR_INSERT(DoubleType) +GET_OR_INSERT(DurationType); +GET_OR_INSERT(TimestampType); +GET_OR_INSERT(Date32Type); +GET_OR_INSERT(Date64Type); +GET_OR_INSERT(Time32Type); +GET_OR_INSERT(Time64Type); +GET_OR_INSERT(DayTimeIntervalType); +GET_OR_INSERT(MonthIntervalType); #undef GET_OR_INSERT diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 455cb3df7b1..1b97f98f290 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -37,6 +37,7 @@ #include "arrow/util/decimal.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -97,6 +98,15 @@ class ARROW_EXPORT DictionaryMemoTable { Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out); Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out); Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out); + Status GetOrInsert(const DurationType*, int64_t value, int32_t* out); + Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out); + Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const DayTimeIntervalType*, + DayTimeIntervalType::DayMilliseconds value, int32_t* out); + Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out); Status GetOrInsert(const FloatType*, float value, int32_t* out); Status GetOrInsert(const DoubleType*, double value, int32_t* out); @@ -282,6 +292,163 @@ class DictionaryBuilderBase : public ArrayBuilder { return indices_builder_.AppendEmptyValues(length); } + Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override { + if (!scalar.type->Equals(type())) { + return Status::Invalid("Cannot append scalar of type ", 
scalar.type->ToString(), + " to builder for type ", type()->ToString()); + } + if (!scalar.is_valid) return AppendNulls(n_repeats); + + const auto& dict_ty = internal::checked_cast(*scalar.type); + const DictionaryScalar& dict_scalar = + internal::checked_cast(scalar); + const auto& dict = internal::checked_cast::ArrayType&>( + *dict_scalar.value.dictionary); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::INT8: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::UINT16: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::INT16: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::UINT32: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::INT32: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::UINT64: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + case Type::INT64: { + const auto& value = dict.GetView( + internal::checked_cast(*dict_scalar.value.index).value); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + break; + } + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + + Status AppendScalars(const ScalarVector& scalars) override { + for (const auto& scalar : scalars) { + ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1)); + } + return Status::OK(); + } + + Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final { + // Visit the indices and insert the unpacked values. 
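[Editor's note - illustrative, not part of the patch] Both AppendScalar and AppendArraySlice unpack the incoming dictionary values and feed them back through Append(), so they are re-inserted into this builder's own memo table rather than copied index-for-index. That is why the updated AssertAppendScalar test above stops comparing dictionary values exactly: the rebuilt dictionary can differ in order, e.g.

    // appending indices [2, 0] over ["a", "b", "c"] appends the values "c", "a",
    // so the builder's resulting dictionary starts ["c", "a", ...]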
+ const auto& dict_ty = internal::checked_cast(*array.type); + const typename TypeTraits::ArrayType dict(array.dictionary); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: { + const uint8_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::INT8: { + const int8_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::UINT16: { + const uint16_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::INT16: { + const int16_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::UINT32: { + const uint32_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::INT32: { + const int32_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::UINT64: { + const uint64_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + case Type::INT64: { + const int64_t* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](int64_t position) { return Append(dict.GetView(values[position])); }, + [&]() { return AppendNull(); }); + } + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + /// \brief Insert values into the dictionary's memo, but do not append any /// indices. 
Can be used to initialize a new builder with known dictionary /// values diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 37cc9e07ad4..115a97e9389 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -41,14 +41,10 @@ struct DictionaryBuilderCase { } Status Visit(const NullType&) { return CreateFor(); } - Status Visit(const BinaryType&) { return Create(); } - Status Visit(const StringType&) { return Create(); } - Status Visit(const LargeBinaryType&) { - return Create>(); - } - Status Visit(const LargeStringType&) { - return Create>(); - } + Status Visit(const BinaryType&) { return CreateFor(); } + Status Visit(const StringType&) { return CreateFor(); } + Status Visit(const LargeBinaryType&) { return CreateFor(); } + Status Visit(const LargeStringType&) { return CreateFor(); } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } Status Visit(const Decimal256Type&) { return CreateFor(); } @@ -63,19 +59,50 @@ struct DictionaryBuilderCase { template Status CreateFor() { - return Create>(); - } - - template - Status Create() { - BuilderType* builder; + using AdaptiveBuilderType = DictionaryBuilder; if (dictionary != nullptr) { - builder = new BuilderType(dictionary, pool); + out->reset(new AdaptiveBuilderType(dictionary, pool)); + } else if (exact_index_type) { + switch (index_type->id()) { + case Type::UINT8: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT8: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT16: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT16: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT32: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT32: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT64: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT64: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + default: + return Status::TypeError("MakeBuilder: invalid index type ", *index_type); + } } else { auto start_int_size = internal::GetByteWidth(*index_type); - builder = new BuilderType(start_int_size, value_type, pool); + out->reset(new AdaptiveBuilderType(start_int_size, value_type, pool)); } - out->reset(builder); return Status::OK(); } @@ -85,138 +112,130 @@ struct DictionaryBuilderCase { const std::shared_ptr& index_type; const std::shared_ptr& value_type; const std::shared_ptr& dictionary; + bool exact_index_type; std::unique_ptr* out; }; -#define BUILDER_CASE(TYPE_CLASS) \ - case TYPE_CLASS##Type::type_id: \ - out->reset(new TYPE_CLASS##Builder(type, pool)); \ +struct MakeBuilderImpl { + template + enable_if_not_nested Visit(const T&) { + out.reset(new typename TypeTraits::BuilderType(type, pool)); return Status::OK(); + } -Result>> FieldBuilders(const DataType& type, - MemoryPool* pool) { - std::vector> field_builders; + Status Visit(const DictionaryType& dict_type) { + DictionaryBuilderCase visitor = {pool, + dict_type.index_type(), + dict_type.value_type(), + /*dictionary=*/nullptr, + exact_index_type, + &out}; + return visitor.Make(); + } - for (const auto& field : type.fields()) { - std::unique_ptr builder; - RETURN_NOT_OK(MakeBuilder(pool, field->type(), &builder)); - 
field_builders.emplace_back(std::move(builder)); + Status Visit(const ListType& list_type) { + std::shared_ptr value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new ListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); } - return field_builders; -} + Status Visit(const LargeListType& list_type) { + std::shared_ptr value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new LargeListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); + } -Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, - std::unique_ptr* out) { - switch (type->id()) { - case Type::NA: { - out->reset(new NullBuilder(pool)); - return Status::OK(); - } - BUILDER_CASE(UInt8); - BUILDER_CASE(Int8); - BUILDER_CASE(UInt16); - BUILDER_CASE(Int16); - BUILDER_CASE(UInt32); - BUILDER_CASE(Int32); - BUILDER_CASE(UInt64); - BUILDER_CASE(Int64); - BUILDER_CASE(Date32); - BUILDER_CASE(Date64); - BUILDER_CASE(Duration); - BUILDER_CASE(Time32); - BUILDER_CASE(Time64); - BUILDER_CASE(Timestamp); - BUILDER_CASE(MonthInterval); - BUILDER_CASE(DayTimeInterval); - BUILDER_CASE(MonthDayNanoInterval); - BUILDER_CASE(Boolean); - BUILDER_CASE(HalfFloat); - BUILDER_CASE(Float); - BUILDER_CASE(Double); - BUILDER_CASE(String); - BUILDER_CASE(Binary); - BUILDER_CASE(LargeString); - BUILDER_CASE(LargeBinary); - BUILDER_CASE(FixedSizeBinary); - BUILDER_CASE(Decimal128); - BUILDER_CASE(Decimal256); - - case Type::DICTIONARY: { - const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type.index_type(), - dict_type.value_type(), nullptr, out}; - return visitor.Make(); - } + Status Visit(const MapType& map_type) { + ARROW_ASSIGN_OR_RAISE(auto key_builder, ChildBuilder(map_type.key_type())); + ARROW_ASSIGN_OR_RAISE(auto item_builder, ChildBuilder(map_type.item_type())); + out.reset( + new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type)); + return Status::OK(); + } - case Type::LIST: { - std::unique_ptr value_builder; - std::shared_ptr value_type = - internal::checked_cast(*type).value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new ListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const FixedSizeListType& list_type) { + auto value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new FixedSizeListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); + } - case Type::LARGE_LIST: { - std::unique_ptr value_builder; - std::shared_ptr value_type = - internal::checked_cast(*type).value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new LargeListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const StructType& struct_type) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new StructBuilder(type, pool, std::move(field_builders))); + return Status::OK(); + } - case Type::MAP: { - const auto& map_type = internal::checked_cast(*type); - std::unique_ptr key_builder, item_builder; - RETURN_NOT_OK(MakeBuilder(pool, map_type.key_type(), &key_builder)); - RETURN_NOT_OK(MakeBuilder(pool, map_type.item_type(), &item_builder)); - out->reset( - new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type)); - return Status::OK(); - } + 
Status Visit(const SparseUnionType&) {
+    ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+    out.reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
+    return Status::OK();
+  }
 
-    case Type::FIXED_SIZE_LIST: {
-      const auto& list_type = internal::checked_cast<const FixedSizeListType&>(*type);
-      std::unique_ptr<ArrayBuilder> value_builder;
-      auto value_type = list_type.value_type();
-      RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
-      out->reset(new FixedSizeListBuilder(pool, std::move(value_builder), type));
-      return Status::OK();
-    }
 
+  Status Visit(const DenseUnionType&) {
+    ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+    out.reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
+    return Status::OK();
+  }
 
-    case Type::STRUCT: {
-      ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
-      out->reset(new StructBuilder(type, pool, std::move(field_builders)));
-      return Status::OK();
-    }
 
+  Status Visit(const ExtensionType&) { return NotImplemented(); }
+  Status Visit(const DataType&) { return NotImplemented(); }
 
-    case Type::SPARSE_UNION: {
-      ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
-      out->reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
-      return Status::OK();
-    }
 
+  Status NotImplemented() {
+    return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
+                                  type->ToString());
+  }
 
-    case Type::DENSE_UNION: {
-      ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
-      out->reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
-      return Status::OK();
-    }
 
+  Result<std::unique_ptr<ArrayBuilder>> ChildBuilder(
+      const std::shared_ptr<DataType>& type) {
+    MakeBuilderImpl impl{pool, type, exact_index_type, /*out=*/nullptr};
+    RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+    return std::move(impl.out);
+  }
 
-    default:
-      break;
+  Result<std::vector<std::shared_ptr<ArrayBuilder>>> FieldBuilders(const DataType& type,
+                                                                   MemoryPool* pool) {
+    std::vector<std::shared_ptr<ArrayBuilder>> field_builders;
+    for (const auto& field : type.fields()) {
+      std::unique_ptr<ArrayBuilder> builder;
+      MakeBuilderImpl impl{pool, field->type(), exact_index_type, /*out=*/nullptr};
+      RETURN_NOT_OK(VisitTypeInline(*field->type(), &impl));
+      field_builders.emplace_back(std::move(impl.out));
+    }
+    return field_builders;
   }
-  return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
-                                type->ToString());
+
+  MemoryPool* pool;
+  const std::shared_ptr<DataType>& type;
+  bool exact_index_type;
+  std::unique_ptr<ArrayBuilder> out;
+};
+
+Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+                   std::unique_ptr<ArrayBuilder>* out) {
+  MakeBuilderImpl impl{pool, type, /*exact_index_type=*/false, /*out=*/nullptr};
+  RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+  *out = std::move(impl.out);
+  return Status::OK();
+}
+
+Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+                             std::unique_ptr<ArrayBuilder>* out) {
+  MakeBuilderImpl impl{pool, type, /*exact_index_type=*/true, /*out=*/nullptr};
+  RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+  *out = std::move(impl.out);
+  return Status::OK();
 }
 
 Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
                              const std::shared_ptr<Array>& dictionary,
                              std::unique_ptr<ArrayBuilder>* out) {
   const auto& dict_type = static_cast<const DictionaryType&>(*type);
-  DictionaryBuilderCase visitor = {pool, dict_type.index_type(), dict_type.value_type(),
-                                   dictionary, out};
+  DictionaryBuilderCase visitor = {
+      pool,       dict_type.index_type(),      dict_type.value_type(),
+      dictionary, /*exact_index_type=*/false,  out};
   return visitor.Make();
 }
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 6a2ca01c439..ff4df63b966 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1891,7 +1891,7 @@ static Status ExecVarWidthArrayCaseWhenImpl( const bool have_else_arg = static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1); std::unique_ptr raw_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder)); RETURN_NOT_OK(raw_builder->Reserve(batch.length)); RETURN_NOT_OK(reserve_data(raw_builder.get())); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index c8c22f5ae76..38be7eb8780 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -680,6 +680,65 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { ArrayFromJSON(utf8(), R"([null, null, null, null])")); } +TYPED_TEST(TestCaseWhenDict, NestedSimple) { + auto make_list = [](const std::shared_ptr& indices, + const std::shared_ptr& backing_array) { + EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); + return result; + }; + auto index_type = default_type_instance(); + auto inner_type = dictionary(index_type, utf8()); + auto type = list(inner_type); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", "b", "bc", "def"])"; + auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); + auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); + auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); + auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); + + CheckScalarNonRecursive( + "case_when", {MakeStruct({cond1, cond2}), values1, values2}, + make_list(ArrayFromJSON(int32(), "[0, 2, 2, null, 2]"), + DictArrayFromJSON(inner_type, "[0, null]", R"(["a"])"))); + CheckScalarNonRecursive( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, + make_list(ArrayFromJSON(int32(), "[0, 2, null, null, 2]"), + DictArrayFromJSON(inner_type, "[0, null]", R"(["a"])"))); + CheckScalarNonRecursive( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1}, + make_list(ArrayFromJSON(int32(), "[0, 2, null, 2, 3]"), + DictArrayFromJSON(inner_type, "[0, null, 1]", R"(["a", "b"])"))); + + CheckScalarNonRecursive( + "case_when", + { + Datum(MakeStruct({cond1, cond2})), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[0, 1]", dict))), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[2, 3]", dict))), + }, + make_list(ArrayFromJSON(int32(), "[0, 2, 4, null, 6]"), + DictArrayFromJSON(inner_type, "[0, 1, 0, 1, 2, 3]", dict))); + + CheckScalarNonRecursive( + "case_when", {MakeStruct({Datum(true), Datum(false)}), values1, values2}, values1); + CheckScalarNonRecursive( + "case_when", {MakeStruct({Datum(false), Datum(true)}), values1, values2}, values2); + CheckScalarNonRecursive("case_when", {MakeStruct({Datum(false)}), values1, values2}, + 
values2); + CheckScalarNonRecursive("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}, + values_null); +} + TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); From 16fe2104113bcfcf003e414887c2b876316ea356 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 31 Aug 2021 12:47:14 -0400 Subject: [PATCH 06/21] ARROW-13691: [C++] Rebase --- cpp/src/arrow/array/builder_dict.cc | 1 + cpp/src/arrow/array/builder_dict.h | 2 ++ 2 files changed, 3 insertions(+) diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index b554c1d7099..d247316999d 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -182,6 +182,7 @@ GET_OR_INSERT(Date32Type); GET_OR_INSERT(Date64Type); GET_OR_INSERT(Time32Type); GET_OR_INSERT(Time64Type); +GET_OR_INSERT(MonthDayNanoIntervalType); GET_OR_INSERT(DayTimeIntervalType); GET_OR_INSERT(MonthIntervalType); diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 1b97f98f290..df128c08c19 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -104,6 +104,8 @@ class ARROW_EXPORT DictionaryMemoTable { Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out); Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out); Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const MonthDayNanoIntervalType*, + MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out); Status GetOrInsert(const DayTimeIntervalType*, DayTimeIntervalType::DayMilliseconds value, int32_t* out); Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out); From 8e0e33323725513c1eb35c1923abf300b0b4afb4 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 11:04:04 -0400 Subject: [PATCH 07/21] ARROW-13573: [C++] Always unify dictionaries --- .../arrow/compute/kernels/scalar_if_else.cc | 447 ++---------------- .../compute/kernels/scalar_if_else_test.cc | 239 +++------- cpp/src/arrow/compute/kernels/test_util.cc | 69 +++ cpp/src/arrow/compute/kernels/test_util.h | 4 + 4 files changed, 179 insertions(+), 580 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index ff4df63b966..ab7071af20c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -15,13 +15,10 @@ // specific language governing permissions and limitations // under the License. -#include - #include "arrow/array/builder_nested.h" #include "arrow/array/builder_primitive.h" #include "arrow/array/builder_time.h" #include "arrow/array/builder_union.h" -#include "arrow/array/dict_internal.h" #include "arrow/compute/api.h" #include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/util_internal.h" @@ -30,7 +27,6 @@ #include "arrow/util/bitmap.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_reader.h" -#include "arrow/util/int_util.h" namespace arrow { @@ -1062,109 +1058,6 @@ void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_funct DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } -// Given a reference dictionary, computes indices to map dictionary values from a -// comparison dictionary to the reference. 
-class DictionaryRemapper { - public: - virtual ~DictionaryRemapper() = default; - virtual Status Init(const Array& dictionary) = 0; - virtual Result> Remap(const Array& dictionary) = 0; - Result> Remap(const Datum& has_dictionary) { - DCHECK_EQ(has_dictionary.type()->id(), Type::DICTIONARY); - if (has_dictionary.is_scalar()) { - return Remap(*checked_cast(*has_dictionary.scalar()) - .value.dictionary); - } else { - return Remap(*MakeArray(has_dictionary.array()->dictionary)); - } - } -}; - -template -class DictionaryRemapperImpl : public DictionaryRemapper { - public: - using ArrayType = typename TypeTraits::ArrayType; - using MemoTableType = typename arrow::internal::HashTraits::MemoTableType; - - explicit DictionaryRemapperImpl(MemoryPool* pool) : pool_(pool), memo_table_(pool) {} - - Status Init(const Array& dictionary) override { - const ArrayType& values = checked_cast(dictionary); - if (values.length() > std::numeric_limits::max()) { - return Status::CapacityError("Cannot remap dictionary with more than ", - std::numeric_limits::max(), - " elements, have: ", values.length()); - } - for (int32_t i = 0; i < values.length(); ++i) { - if (values.IsNull(i)) { - memo_table_.GetOrInsertNull(); - continue; - } - int32_t unused_memo_index = 0; - RETURN_NOT_OK(memo_table_.GetOrInsert(values.GetView(i), &unused_memo_index)); - } - return Status::OK(); - } - - Result> Remap(const Array& dictionary) override { - const ArrayType& values = checked_cast(dictionary); - std::shared_ptr valid_buffer; - ARROW_ASSIGN_OR_RAISE(auto indices_buffer, - AllocateBuffer(dictionary.length() * sizeof(int32_t), pool_)); - int32_t* indices = reinterpret_cast(indices_buffer->mutable_data()); - int64_t null_count = 0; - for (int64_t i = 0; i < values.length(); ++i) { - int32_t index = -1; - index = - values.IsNull(i) ? memo_table_.GetNull() : memo_table_.Get(values.GetView(i)); - indices[i] = std::max(0, index); - if (index == arrow::internal::kKeyNotFound && !valid_buffer) { - ARROW_ASSIGN_OR_RAISE( - valid_buffer, - AllocateBuffer(BitUtil::BytesForBits(dictionary.length()), pool_)); - std::memset(valid_buffer->mutable_data(), 0xFF, valid_buffer->size()); - } - if (index == arrow::internal::kKeyNotFound) { - BitUtil::ClearBit(valid_buffer->mutable_data(), i); - null_count++; - } - } - return arrow::internal::make_unique(dictionary.length(), - std::move(indices_buffer), - std::move(valid_buffer), null_count); - } - - private: - MemoryPool* pool_; - MemoTableType memo_table_; -}; - -struct MakeRemapper { - template - enable_if_no_memoize Visit(const T& value_type) { - return Status::NotImplemented("Unification of ", value_type, - " dictionaries is not implemented"); - } - - template - enable_if_memoize Visit(const T&) { - result_.reset(new DictionaryRemapperImpl(pool_)); - return Status::OK(); - } - - static Result> Make(MemoryPool* pool, - const Array& dictionary) { - const auto& value_type = *dictionary.type(); - MakeRemapper impl{pool, /*result_=*/nullptr}; - RETURN_NOT_OK(VisitTypeInline(value_type, &impl)); - RETURN_NOT_OK(impl.result_->Init(dictionary)); - return std::move(impl.result_); - } - - MemoryPool* pool_; - std::unique_ptr result_; -}; - // Helper to copy or broadcast fixed-width values between buffers. 
template struct CopyFixedWidth {}; @@ -1191,150 +1084,14 @@ struct CopyFixedWidth> { const CType value = UnboxScalar::Unbox(scalar); std::fill(out_values + out_offset, out_values + out_offset + length, value); } - static void CopyScalar(const Scalar& scalar, const Int32Array& transpose_map, - const int64_t length, uint8_t* raw_out_values, - const int64_t out_offset) { - CType* out_values = reinterpret_cast(raw_out_values); - const CType value = UnboxScalar::Unbox(scalar); - const CType transposed = static_cast(transpose_map.raw_values()[value]); - std::fill(out_values + out_offset, out_values + out_offset + length, transposed); - } static void CopyArray(const DataType&, const uint8_t* in_values, const int64_t in_offset, const int64_t length, uint8_t* raw_out_values, const int64_t out_offset) { std::memcpy(raw_out_values + out_offset * sizeof(CType), in_values + in_offset * sizeof(CType), length * sizeof(CType)); } - static void CopyArray(const DataType&, const Int32Array& transpose_map, - const uint8_t* in_values, const int64_t in_offset, - const int64_t length, uint8_t* raw_out_values, - const int64_t out_offset) { - arrow::internal::TransposeInts( - reinterpret_cast(in_values) + in_offset, - reinterpret_cast(raw_out_values) + out_offset, length, - transpose_map.raw_values()); - } }; -template -struct CopyFixedWidth> { - static void CopyScalar(const Scalar& scalar, const int64_t length, - uint8_t* raw_out_values, const int64_t out_offset) { - const auto& index = *checked_cast(scalar).value.index; - switch (index.type->id()) { - case arrow::Type::INT8: - case arrow::Type::UINT8: - CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); - break; - case arrow::Type::INT16: - case arrow::Type::UINT16: - CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); - break; - case arrow::Type::INT32: - case arrow::Type::UINT32: - CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); - break; - case arrow::Type::INT64: - case arrow::Type::UINT64: - CopyFixedWidth::CopyScalar(index, length, raw_out_values, out_offset); - break; - default: - ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type; - } - } - static void CopyScalar(const Scalar& scalar, const Int32Array& transpose_map, - const int64_t length, uint8_t* raw_out_values, - const int64_t out_offset) { - const auto& index = *checked_cast(scalar).value.index; - switch (index.type->id()) { - case arrow::Type::INT8: - case arrow::Type::UINT8: - CopyFixedWidth::CopyScalar(index, transpose_map, length, - raw_out_values, out_offset); - break; - case arrow::Type::INT16: - case arrow::Type::UINT16: - CopyFixedWidth::CopyScalar(index, transpose_map, length, - raw_out_values, out_offset); - break; - case arrow::Type::INT32: - case arrow::Type::UINT32: - CopyFixedWidth::CopyScalar(index, transpose_map, length, - raw_out_values, out_offset); - break; - case arrow::Type::INT64: - case arrow::Type::UINT64: - CopyFixedWidth::CopyScalar(index, transpose_map, length, - raw_out_values, out_offset); - break; - default: - ARROW_CHECK(false) << "Invalid index type for dictionary: " << *index.type; - } - } - static void CopyArray(const DataType& type, const uint8_t* in_values, - const int64_t in_offset, const int64_t length, - uint8_t* raw_out_values, const int64_t out_offset) { - const auto& index_type = *checked_cast(type).index_type(); - switch (index_type.id()) { - case arrow::Type::INT8: - case arrow::Type::UINT8: - CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, - 
raw_out_values, out_offset); - break; - case arrow::Type::INT16: - case arrow::Type::UINT16: - CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, - raw_out_values, out_offset); - break; - case arrow::Type::INT32: - case arrow::Type::UINT32: - CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, - raw_out_values, out_offset); - break; - case arrow::Type::INT64: - case arrow::Type::UINT64: - CopyFixedWidth::CopyArray(index_type, in_values, in_offset, length, - raw_out_values, out_offset); - break; - default: - ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type; - } - } - static void CopyArray(const DataType& type, const Int32Array& transpose_map, - const uint8_t* in_values, const int64_t in_offset, - const int64_t length, uint8_t* raw_out_values, - const int64_t out_offset) { - const auto& index_type = *checked_cast(type).index_type(); - switch (index_type.id()) { - case arrow::Type::INT8: - case arrow::Type::UINT8: - CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, - in_offset, length, raw_out_values, - out_offset); - break; - case arrow::Type::INT16: - case arrow::Type::UINT16: - CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, - in_offset, length, raw_out_values, - out_offset); - break; - case arrow::Type::INT32: - case arrow::Type::UINT32: - CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, - in_offset, length, raw_out_values, - out_offset); - break; - case arrow::Type::INT64: - case arrow::Type::UINT64: - CopyFixedWidth::CopyArray(index_type, transpose_map, in_values, - in_offset, length, raw_out_values, - out_offset); - break; - default: - ARROW_CHECK(false) << "Invalid index type for dictionary: " << index_type; - } - } -}; template struct CopyFixedWidth> { static void CopyScalar(const Scalar& values, const int64_t length, @@ -1389,11 +1146,8 @@ struct CopyFixedWidth> { // Copy fixed-width values from a scalar/array datum into an output values buffer template -enable_if_t::value> CopyValues( - const Datum& in_values, const int64_t in_offset, const int64_t length, - uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset, - const Int32Array* transpose_map = nullptr) { - DCHECK(!transpose_map); +void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length, + uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) { if (in_values.is_scalar()) { const auto& scalar = *in_values.scalar(); if (out_valid) { @@ -1423,123 +1177,6 @@ enable_if_t::value> CopyValues( } } -// Copy values, optionally transposing dictionary indices. 
-template -enable_if_dictionary CopyValues(const Datum& in_values, const int64_t in_offset, - const int64_t length, uint8_t* out_valid, - uint8_t* out_values, const int64_t out_offset, - const Int32Array* transpose_map) { - if (in_values.is_scalar()) { - const auto& scalar = *in_values.scalar(); - if (out_valid) { - BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid); - } - if (transpose_map) { - CopyFixedWidth::CopyScalar(scalar, *transpose_map, length, out_values, - out_offset); - } else { - CopyFixedWidth::CopyScalar(scalar, length, out_values, out_offset); - } - } else { - const ArrayData& array = *in_values.array(); - if (out_valid) { - if (array.MayHaveNulls()) { - if (length == 1) { - // CopyBitmap is slow for short runs - BitUtil::SetBitTo( - out_valid, out_offset, - BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset)); - } else { - arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset, - length, out_valid, out_offset); - } - } else { - BitUtil::SetBitsTo(out_valid, out_offset, length, true); - } - } - if (transpose_map) { - CopyFixedWidth::CopyArray(*array.type, *transpose_map, - array.buffers[1]->data(), array.offset + in_offset, - length, out_values, out_offset); - } else { - CopyFixedWidth::CopyArray(*array.type, array.buffers[1]->data(), - array.offset + in_offset, length, out_values, - out_offset); - } - } -} - -/// Check that we can actually remap dictionary indices from one dictionary to -/// another without losing data. -struct CheckValidTranspositionArrayImpl { - template - enable_if_number Visit(const T&) { - using c_type = typename T::c_type; - const c_type* values = arr.GetValues(1); - // TODO: eventually offer the option to zero out the bitmap instead - return arrow::internal::VisitSetBitRuns( - arr.buffers[0], arr.offset + offset, length, - [&](int64_t position, int64_t length) { - for (int64_t i = 0; i < length; i++) { - const uint64_t idx = static_cast(values[offset + position + i]); - if (!BitUtil::GetBit(transpose_valid, idx)) { - return Status::Invalid("Cannot map dictionary index ", idx, " at position ", - offset + position + i, " to the common dictionary"); - } - } - return Status::OK(); - }); - } - - Status Visit(const DataType& ty) { - return Status::TypeError("Dictionary cannot have index type", ty); - } - - const ArrayData& arr; - const int64_t offset; - const int64_t length; - const uint8_t* transpose_valid; -}; - -struct CheckValidTranspositionScalarImpl { - template - enable_if_number Visit(const T&) { - const uint64_t idx = static_cast(UnboxScalar::Unbox( - *checked_cast(scalar).value.index)); - // TODO: eventually offer the option to zero out the bitmap instead - if (!BitUtil::GetBit(transpose_valid, idx)) { - return Status::Invalid("Cannot map dictionary index ", idx, - " to the common dictionary"); - } - return Status::OK(); - } - - Status Visit(const DataType& ty) { - return Status::TypeError("Dictionary cannot have index type", ty); - } - - const Scalar& scalar; - const uint8_t* transpose_valid; -}; - -Status CheckValidTransposition(const Datum& values, const int64_t offset, - const int64_t length, const Int32Array* transpose_map) { - // Note we assume the transpose map never has an offset - if (!transpose_map || transpose_map->null_count() == 0) return Status::OK(); - DCHECK_EQ(values.type()->id(), Type::DICTIONARY); - if (values.is_scalar()) { - const Scalar& scalar = *values.scalar(); - CheckValidTranspositionScalarImpl impl{scalar, transpose_map->null_bitmap_data()}; - return 
VisitTypeInline( - *checked_cast(*scalar.type).index_type(), &impl); - } - const ArrayData& arr = *values.array(); - CheckValidTranspositionArrayImpl impl{arr, offset, length, - transpose_map->null_bitmap_data()}; - return VisitTypeInline(*checked_cast(*arr.type).index_type(), - &impl); -} - // Specialized helper to copy a single value from a source array. Allows avoiding // repeatedly calling MayHaveNulls and Buffer::data() which have internal checks that // add up when called in a loop. @@ -1657,9 +1294,9 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out // All conditions false, no 'else' argument result = MakeNullScalar(out->type()); } - CopyValues( - result, /*in_offset=*/0, batch.length, output->GetMutableValues(0, 0), - output->GetMutableValues(1, 0), output->offset, /*transpose_map=*/nullptr); + CopyValues(result, /*in_offset=*/0, batch.length, + output->GetMutableValues(0, 0), + output->GetMutableValues(1, 0), output->offset); return Status::OK(); } @@ -1679,28 +1316,10 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) uint8_t* out_valid = output->buffers[0]->mutable_data(); uint8_t* out_values = output->buffers[1]->mutable_data(); - std::unique_ptr remapper; - std::unique_ptr transpose_map; - if (is_dictionary_type::value) { - // We always use the dictionary of the first argument - const Datum& dict_from = batch[1]; - if (dict_from.is_scalar()) { - output->dictionary = checked_cast(*dict_from.scalar()) - .value.dictionary->data(); - } else { - output->dictionary = dict_from.array()->dictionary; - } - ARROW_ASSIGN_OR_RAISE( - remapper, MakeRemapper::Make(ctx->memory_pool(), *MakeArray(output->dictionary))); - } - if (have_else_arg) { // Copy 'else' value into output - if (is_dictionary_type::value) { - ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(batch.values.back())); - } CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, - out_values, out_offset, transpose_map.get()); + out_values, out_offset); } else { // There's no 'else' argument, so we should have an all-null validity bitmap BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false); @@ -1719,10 +1338,6 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) const Datum& values_datum = batch[i + 1]; int64_t offset = 0; - if (is_dictionary_type::value) { - ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(values_datum)); - } - if (cond_array.GetNullCount() == 0) { // If no valid buffer, visit mask & cond bitmap simultaneously BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset, @@ -1730,19 +1345,15 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) while (offset < batch.length) { const auto block = counter.NextAndWord(); if (block.AllSet()) { - RETURN_NOT_OK(CheckValidTransposition(values_datum, offset, block.length, - transpose_map.get())); CopyValues(values_datum, offset, block.length, out_valid, out_values, - out_offset + offset, transpose_map.get()); + out_offset + offset); BitUtil::SetBitsTo(mask, offset, block.length, false); } else if (block.popcount) { for (int64_t j = 0; j < block.length; ++j) { if (BitUtil::GetBit(mask, offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j)) { - RETURN_NOT_OK(CheckValidTransposition(values_datum, offset + j, - /*length=*/1, transpose_map.get())); CopyValues(values_datum, offset + j, /*length=*/1, out_valid, - out_values, out_offset + offset + j, transpose_map.get()); + 
out_values, out_offset + offset + j); BitUtil::SetBitTo(mask, offset + j, false); } } @@ -1755,31 +1366,25 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, {cond_values, cond_offset, batch.length}, {cond_valid, cond_offset, batch.length}}; - Status valid_transposition = Status::OK(); Bitmap::VisitWords(bitmaps, [&](std::array words) { const uint64_t word = words[0] & words[1] & words[2]; const int64_t block_length = std::min(64, batch.length - offset); if (word == std::numeric_limits::max()) { - valid_transposition &= CheckValidTransposition( - values_datum, offset, block_length, transpose_map.get()); CopyValues(values_datum, offset, block_length, out_valid, out_values, - out_offset + offset, transpose_map.get()); + out_offset + offset); BitUtil::SetBitsTo(mask, offset, block_length, false); } else if (word) { for (int64_t j = 0; j < block_length; ++j) { if (BitUtil::GetBit(mask, offset + j) && BitUtil::GetBit(cond_valid, cond_offset + offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j)) { - valid_transposition &= CheckValidTransposition( - values_datum, offset + j, /*length=*/1, transpose_map.get()); CopyValues(values_datum, offset + j, /*length=*/1, out_valid, - out_values, out_offset + offset + j, transpose_map.get()); + out_values, out_offset + offset + j); BitUtil::SetBitTo(mask, offset + j, false); } } } }); - RETURN_NOT_OK(valid_transposition); } } if (!have_else_arg) { @@ -1810,18 +1415,6 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) } offset += block.length; } - } else if (is_dictionary_type::value) { - // Check that any 'else' slots that were not overwritten are valid transpositions. - arrow::internal::SetBitRunReader reader(mask, /*offset=*/0, batch.length); - if (is_dictionary_type::value) { - ARROW_ASSIGN_OR_RAISE(transpose_map, remapper->Remap(batch.values.back())); - } - while (true) { - const auto run = reader.NextRun(); - if (run.length == 0) break; - RETURN_NOT_OK(CheckValidTransposition(batch.values.back(), run.position, run.length, - transpose_map.get())); - } } return Status::OK(); } @@ -2120,6 +1713,26 @@ struct CaseWhenFunctor> { } }; +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return ExecVarWidthArrayCaseWhen( + ctx, batch, out, + // ReserveData + [&](ArrayBuilder* raw_builder) { return Status::OK(); }); + } +}; + struct CoalesceFunction : ScalarFunction { using ScalarFunction::ScalarFunction; diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 38be7eb8780..a489934a422 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -637,22 +637,20 @@ TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes); TYPED_TEST(TestCaseWhenDict, Simple) { auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + // TODO: test dictionaries with nulls for (const auto& dict : - {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, - 
JsonDict{int64(), "[1, null, 2, 3]"}, - JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { + {JsonDict{utf8(), R"(["a", "b", "bc", "def"])"}, JsonDict{int64(), "[1, 4, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", "2.34", "3.45", "6.78"])"}}) { auto type = dictionary(default_type_instance(), dict.type); auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value); auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value); // Easy case: all arguments have the same dictionary - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, - DictArrayFromJSON(type, "[0, null, null, null]", dict.value)); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, - DictArrayFromJSON(type, "[0, null, null, 1]", dict.value)); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, - DictArrayFromJSON(type, "[null, null, null, 1]", dict.value)); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values_null, values2, values1}); } } @@ -660,7 +658,7 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto dict = R"(["a", null, "bc", "def"])"; + auto dict = R"(["a", "", "bc", "def"])"; auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); @@ -668,16 +666,14 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])"); // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}, - ArrayFromJSON(utf8(), R"(["a", null, null, null])")); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_decoded, values2_dict}, - ArrayFromJSON(utf8(), R"(["a", null, null, null])")); - CheckScalar("case_when", - {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded}, - ArrayFromJSON(utf8(), R"(["a", null, null, null])")); - CheckScalar("case_when", - {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded}, - ArrayFromJSON(utf8(), R"([null, null, null, null])")); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values1_decoded, values2_dict}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1_dict, values2_dict, + values1_decoded}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values_null, values2_dict, + values1_decoded}); } TYPED_TEST(TestCaseWhenDict, NestedSimple) { @@ -699,52 +695,38 @@ TYPED_TEST(TestCaseWhenDict, NestedSimple) { auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); - CheckScalarNonRecursive( - "case_when", {MakeStruct({cond1, cond2}), values1, values2}, - make_list(ArrayFromJSON(int32(), "[0, 2, 2, null, 2]"), - 
DictArrayFromJSON(inner_type, "[0, null]", R"(["a"])"))); - CheckScalarNonRecursive( - "case_when", - {MakeStruct({cond1, cond2}), values1, - make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, - make_list(ArrayFromJSON(int32(), "[0, 2, null, null, 2]"), - DictArrayFromJSON(inner_type, "[0, null]", R"(["a"])"))); - CheckScalarNonRecursive( - "case_when", - {MakeStruct({cond1, cond2}), values1, - make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1}, - make_list(ArrayFromJSON(int32(), "[0, 2, null, 2, 3]"), - DictArrayFromJSON(inner_type, "[0, null, 1]", R"(["a", "b"])"))); - - CheckScalarNonRecursive( - "case_when", - { - Datum(MakeStruct({cond1, cond2})), - Datum(std::make_shared( - DictArrayFromJSON(inner_type, "[0, 1]", dict))), - Datum(std::make_shared( - DictArrayFromJSON(inner_type, "[2, 3]", dict))), - }, - make_list(ArrayFromJSON(int32(), "[0, 2, 4, null, 6]"), - DictArrayFromJSON(inner_type, "[0, 1, 0, 1, 2, 3]", dict))); - - CheckScalarNonRecursive( - "case_when", {MakeStruct({Datum(true), Datum(false)}), values1, values2}, values1); - CheckScalarNonRecursive( - "case_when", {MakeStruct({Datum(false), Datum(true)}), values1, values2}, values2); - CheckScalarNonRecursive("case_when", {MakeStruct({Datum(false)}), values1, values2}, - values2); - CheckScalarNonRecursive("case_when", - {MakeStruct({Datum(false), Datum(false)}), values1, values2}, - values_null); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), + values2_backing)}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), + values2_backing), + values1}); + + CheckDictionary("case_when", { + Datum(MakeStruct({cond1, cond2})), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[0, 1]", dict))), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[2, 3]", dict))), + }); + + CheckDictionary("case_when", + {MakeStruct({Datum(true), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(true)}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}); } TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto dict1 = R"(["a", null, "bc", "def"])"; - auto dict2 = R"(["bc", "foo", null, "a"])"; + auto dict1 = R"(["a", "", "bc", "def"])"; + auto dict2 = R"(["bc", "foo", "", "a"])"; auto dict3 = R"(["def", "a", "a", "bc"])"; auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); @@ -752,120 +734,51 @@ TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2); auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3); - // For scalar conditions, we borrow the dictionary of the chosen output (or the first - // input when outputting null) - CheckScalar("case_when", {MakeStruct({Datum(true), Datum(false)}), values1, values2}, - values1); - CheckScalar("case_when", {MakeStruct({Datum(false), 
Datum(true)}), values1, values2}, - values2); - CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values1, values2}, - values1_null); - CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values2, values1}, - values2_null); - - // For array conditions, we always borrow the dictionary of the first input - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, - DictArrayFromJSON(type, "[0, null, null, null]", dict1)); - CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, - DictArrayFromJSON(type, "[0, null, null, 1]", dict1)); - - // When mixing dictionaries, we try to map other dictionaries onto the first one - // Don't check the scalar cases since we don't remap dictionaries in that case - CheckScalarNonRecursive( - "case_when", - {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), values1, - values2}, - DictArrayFromJSON(type, "[0, null, null, 2]", dict1)); - CheckScalarNonRecursive( - "case_when", - {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), - ArrayFromJSON(boolean(), "[true, false, true, false]")}), - values1, values2}, - DictArrayFromJSON(type, "[0, null, null, null]", dict1)); - CheckScalarNonRecursive( - "case_when", - {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"), - ArrayFromJSON(boolean(), "[true, true, true, true]")}), - values1, values3}, - DictArrayFromJSON(type, "[3, 0, 0, 2]", dict1)); - CheckScalarNonRecursive( - "case_when", - {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"), - ArrayFromJSON(boolean(), "[true, true, true, true]")}), - values1, values3}, - DictArrayFromJSON(type, "[3, 0, 0, 1]", dict1)); - CheckScalarNonRecursive( + CheckDictionary("case_when", + {MakeStruct({Datum(true), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(true)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values2, values1}); + + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}); + + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, false, false, true]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[true, false, true, false]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}); + CheckDictionary( "case_when", { MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), DictScalarFromJSON(type, "0", dict1), DictScalarFromJSON(type, "0", dict2), - }, - DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1)); - CheckScalarNonRecursive( + }); + CheckDictionary( "case_when", { MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), ArrayFromJSON(boolean(), "[false, false, true, true]")}), DictScalarFromJSON(type, "0", 
dict1), DictScalarFromJSON(type, "0", dict2), - }, - DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1)); - - // If we can't map values from a dictionary, then raise an error - // Unmappable value is in the else clause - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr( - "Cannot map dictionary index 1 at position 1 to the common dictionary"), - CallFunction( - "case_when", - {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]")}), - values1, values2})); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr("Cannot map dictionary index 1 to the common dictionary"), - CallFunction( - "case_when", - {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]")}), - values1, DictScalarFromJSON(type, "1", dict2)})); - // Unmappable value is in a branch (test multiple times to ensure coverage of branches - // in impl) - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr( - "Cannot map dictionary index 1 at position 1 to the common dictionary"), - CallFunction("case_when", - {MakeStruct({Datum(false), - ArrayFromJSON(boolean(), "[true, true, true, true]")}), - values1, values2})); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr( - "Cannot map dictionary index 1 at position 1 to the common dictionary"), - CallFunction("case_when", - {MakeStruct({Datum(false), - ArrayFromJSON(boolean(), "[false, true, false, false]")}), - values1, values2})); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr( - "Cannot map dictionary index 1 at position 1 to the common dictionary"), - CallFunction("case_when", - {MakeStruct({Datum(false), - ArrayFromJSON(boolean(), "[null, true, null, null]")}), - values1, values2})); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr("Cannot map dictionary index 1 to the common dictionary"), - CallFunction("case_when", - {MakeStruct({Datum(false), - ArrayFromJSON(boolean(), "[true, true, true, null]")}), - values1, DictScalarFromJSON(type, "1", dict2)})); - - // ...or optionally, emit null - - // TODO: this is not implemented yet + }); } TEST(TestCaseWhen, Null) { diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 9a779e49163..fc210ae9346 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -24,6 +24,7 @@ #include "arrow/array.h" #include "arrow/array/validate.h" #include "arrow/chunked_array.h" +#include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/function.h" #include "arrow/compute/registry.h" @@ -170,6 +171,74 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte } } +Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args) { + EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args)); + ValidateOutput(actual); + + DatumVector decoded_args; + decoded_args.reserve(args.size()); + for (const auto& arg : args) { + if (arg.type()->id() == Type::DICTIONARY) { + const auto& to_type = checked_cast(*arg.type()).value_type(); + EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type)); + decoded_args.push_back(decoded); + } else { + decoded_args.push_back(arg); + } + } + EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args)); + + if (actual.type()->id() == Type::DICTIONARY) { + const auto& to_type = + checked_cast(*actual.type()).value_type(); + EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type)); + 
AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true);
+  } else {
+    AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+  }
+  return actual;
+}
+
+void CheckDictionary(const std::string& func_name, const DatumVector& args) {
+  auto actual = CheckDictionaryNonRecursive(func_name, args);
+
+  if (actual.is_scalar()) return;
+  ASSERT_TRUE(actual.is_array());
+  ASSERT_GE(actual.length(), 0);
+
+  // Check all scalars
+  for (int64_t i = 0; i < actual.length(); i++) {
+    CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)));
+  }
+
+  // Check slices of the input
+  const auto slice_length = actual.length() / 3;
+  if (slice_length > 0) {
+    CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length));
+    CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length));
+    CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length));
+  }
+
+  // Check empty slice
+  CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0));
+
+  // Check chunked arrays
+  if (slice_length > 0) {
+    DatumVector chunked_args;
+    chunked_args.reserve(args.size());
+    for (const auto& arg : args) {
+      if (arg.is_array()) {
+        auto arr = arg.make_array();
+        ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)};
+        chunked_args.push_back(std::make_shared<ChunkedArray>(std::move(chunks)));
+      } else {
+        chunked_args.push_back(arg);
+      }
+    }
+    CheckDictionaryNonRecursive(func_name, chunked_args);
+  }
+}
+
 void CheckScalarUnary(std::string func_name, Datum input, Datum expected,
                       const FunctionOptions* options) {
   std::vector<Datum> input_vector = {std::move(input)};
diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h
index 0931f3c77bc..89e68950175 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -76,6 +76,10 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs,
 void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected,
                  const FunctionOptions* options = nullptr);
 
+// Like CheckScalar, but gets the expected result by
+// dictionary-decoding arguments and calling the function again.
+void CheckDictionary(const std::string& func_name, const DatumVector& args);
+
 // Just call the function with the given arguments.
void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, const Datum& expected, From 60ffb02768416b4aba7e1d1714ff7c0d6313b591 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 11:44:43 -0400 Subject: [PATCH 08/21] ARROW-13573: [C++] Handle nulls before unifying, refactor --- cpp/src/arrow/array/builder_dict.h | 182 ++++++++++------------------- 1 file changed, 62 insertions(+), 120 deletions(-) diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index df128c08c19..db248ff2dac 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -307,70 +307,22 @@ class DictionaryBuilderBase : public ArrayBuilder { const auto& dict = internal::checked_cast::ArrayType&>( *dict_scalar.value.dictionary); switch (dict_ty.index_type()->id()) { - case Type::UINT8: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::INT8: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::UINT16: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::INT16: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::UINT32: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::INT32: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::UINT64: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } - case Type::INT64: { - const auto& value = dict.GetView( - internal::checked_cast(*dict_scalar.value.index).value); - for (int64_t i = 0; i < n_repeats; i++) { - ARROW_RETURN_NOT_OK(Append(value)); - } - break; - } + case Type::UINT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); default: return Status::TypeError("Invalid index type: ", dict_ty); } @@ -389,62 +341,22 @@ class DictionaryBuilderBase : public ArrayBuilder { const auto& dict_ty = internal::checked_cast(*array.type); const typename TypeTraits::ArrayType 
dict(array.dictionary); switch (dict_ty.index_type()->id()) { - case Type::UINT8: { - const uint8_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::INT8: { - const int8_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::UINT16: { - const uint16_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::INT16: { - const int16_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::UINT32: { - const uint32_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::INT32: { - const int32_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::UINT64: { - const uint64_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } - case Type::INT64: { - const int64_t* values = array.GetValues(1) + offset; - return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), - [&](int64_t position) { return Append(dict.GetView(values[position])); }, - [&]() { return AppendNull(); }); - } + case Type::UINT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT64: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT64: + return AppendArraySliceImpl(dict, array, offset, length); default: return Status::TypeError("Invalid index type: ", dict_ty); } @@ -545,6 +457,36 @@ class DictionaryBuilderBase : public ArrayBuilder { } protected: + template + Status AppendArraySliceImpl(const typename TypeTraits::ArrayType& dict, + const ArrayData& array, int64_t offset, int64_t length) { + const c_type* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, std::min(array.length, length), + [&](c_type position) { + if 
(dict.IsValid(values[position])) { + return Append(dict.GetView(values[position])); + } + return AppendNull(); + }, + [&]() { return AppendNull(); }); + } + + template + Status AppendScalarImpl(const typename TypeTraits::ArrayType& dict, + const Scalar& index_scalar, int64_t n_repeats) { + using ScalarType = typename TypeTraits::ScalarType; + const auto index = internal::checked_cast(index_scalar).value; + if (index_scalar.is_valid && dict.IsValid(index)) { + const auto& value = dict.GetView(index); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + return Status::OK(); + } + return AppendNulls(n_repeats); + } + Status FinishInternal(std::shared_ptr* out) override { std::shared_ptr dictionary; ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary)); From a10888cc02659544bc24c4969b812b9ca35289e0 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 11:50:15 -0400 Subject: [PATCH 09/21] ARROW-13573: [C++] Test dictionaries with nulls --- .../arrow/compute/kernels/scalar_if_else_test.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index a489934a422..266a2298fe3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -637,10 +637,10 @@ TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes); TYPED_TEST(TestCaseWhenDict, Simple) { auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - // TODO: test dictionaries with nulls for (const auto& dict : - {JsonDict{utf8(), R"(["a", "b", "bc", "def"])"}, JsonDict{int64(), "[1, 4, 2, 3]"}, - JsonDict{decimal256(3, 2), R"(["1.23", "2.34", "3.45", "6.78"])"}}) { + {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, + JsonDict{int64(), "[1, null, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { auto type = dictionary(default_type_instance(), dict.type); auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value); @@ -658,7 +658,7 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto dict = R"(["a", "", "bc", "def"])"; + auto dict = R"(["a", null, "bc", "def"])"; auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); @@ -687,7 +687,7 @@ TYPED_TEST(TestCaseWhenDict, NestedSimple) { auto type = list(inner_type); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto dict = R"(["a", "b", "bc", "def"])"; + auto dict = R"(["a", null, "bc", "def"])"; auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), DictArrayFromJSON(inner_type, "[]", dict)); auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); @@ -725,9 +725,9 @@ TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto 
cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto dict1 = R"(["a", "", "bc", "def"])"; - auto dict2 = R"(["bc", "foo", "", "a"])"; - auto dict3 = R"(["def", "a", "a", "bc"])"; + auto dict1 = R"(["a", null, "bc", "def"])"; + auto dict2 = R"(["bc", "foo", null, "a"])"; + auto dict3 = R"(["def", null, "a", "bc"])"; auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1); From 5cbe6d53bf6c513f12b6c5b22f0d97cbdccf57f1 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 12:02:26 -0400 Subject: [PATCH 10/21] ARROW-13573: [C++] Address feedback --- cpp/src/arrow/array/builder_dict.h | 8 +++----- cpp/src/arrow/array/validate.cc | 3 --- .../arrow/compute/kernels/scalar_if_else_test.cc | 14 +++++++++++--- cpp/src/arrow/ipc/json_simple_test.cc | 2 +- cpp/src/arrow/scalar.cc | 2 +- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index db248ff2dac..4d0987e9e1f 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -295,10 +295,6 @@ class DictionaryBuilderBase : public ArrayBuilder { } Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override { - if (!scalar.type->Equals(type())) { - return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), - " to builder for type ", type()->ToString()); - } if (!scalar.is_valid) return AppendNulls(n_repeats); const auto& dict_ty = internal::checked_cast(*scalar.type); @@ -306,6 +302,7 @@ class DictionaryBuilderBase : public ArrayBuilder { internal::checked_cast(scalar); const auto& dict = internal::checked_cast::ArrayType&>( *dict_scalar.value.dictionary); + RETURN_NOT_OK(Reserve(n_repeats)); switch (dict_ty.index_type()->id()) { case Type::UINT8: return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); @@ -340,6 +337,7 @@ class DictionaryBuilderBase : public ArrayBuilder { // Visit the indices and insert the unpacked values. 
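// (Annotation, not a line of the diff: the Reserve(length) call added just
//  below pre-allocates the index builder once, rather than growing it on every
//  Append issued from the per-index-type switch.)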
const auto& dict_ty = internal::checked_cast(*array.type); const typename TypeTraits::ArrayType dict(array.dictionary); + RETURN_NOT_OK(Reserve(length)); switch (dict_ty.index_type()->id()) { case Type::UINT8: return AppendArraySliceImpl(dict, array, offset, length); @@ -462,7 +460,7 @@ class DictionaryBuilderBase : public ArrayBuilder { const ArrayData& array, int64_t offset, int64_t length) { const c_type* values = array.GetValues(1) + offset; return VisitBitBlocks( - array.buffers[0], array.offset + offset, std::min(array.length, length), + array.buffers[0], array.offset + offset, length, [&](c_type position) { if (dict.IsValid(values[position])) { return Append(dict.GetView(values[position])); diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 1715863014c..c66c4f53b9d 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -568,9 +568,6 @@ struct ValidateArrayFullImpl { } Status Visit(const DictionaryType& type) { - if (!data.dictionary) { - return Status::Invalid("Dictionary array has no dictionary"); - } const Status indices_status = CheckBounds(*type.index_type(), 0, data.dictionary->length - 1); if (!indices_status.ok()) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 266a2298fe3..521876e2a11 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -724,14 +724,14 @@ TYPED_TEST(TestCaseWhenDict, NestedSimple) { TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { auto type = dictionary(default_type_instance(), utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); - auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); auto dict1 = R"(["a", null, "bc", "def"])"; auto dict2 = R"(["bc", "foo", null, "a"])"; auto dict3 = R"(["def", null, "a", "bc"])"; auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); - auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1); - auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2); + auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1); + auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2); auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3); CheckDictionary("case_when", @@ -779,6 +779,14 @@ TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { DictScalarFromJSON(type, "0", dict1), DictScalarFromJSON(type, "0", dict2), }); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[false, false, true, true]")}), + DictScalarFromJSON(type, "null", dict1), + DictScalarFromJSON(type, "0", dict2), + }); } TEST(TestCaseWhen, Null) { diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index 372f6bf1d72..273f980c16c 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -1394,7 +1394,7 @@ TEST(TestDictScalarFromJSON, Basics) { auto scalar = DictScalarFromJSON(type, index, dict); auto expected_index = ScalarFromJSON(int32(), index); AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary), - *scalar); + *scalar, /*verbose=*/true); } } diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 
60ba54f82cc..77129035cb1 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -600,7 +600,7 @@ std::shared_ptr DictionaryScalar::Make(std::shared_ptr std::shared_ptr dict) { auto type = dictionary(index->type, dict->type()); return std::make_shared(ValueType{std::move(index), std::move(dict)}, - std::move(type)); + std::move(type), index->is_valid); } namespace { From 345388fef1ddd44110b27a5c5ba102d714600623 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 12:10:29 -0400 Subject: [PATCH 11/21] ARROW-13573: [C++] Add a direct test of dispatch --- cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 521876e2a11..7ba0a223556 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1654,6 +1654,18 @@ TEST(TestCaseWhen, DispatchBest) { CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}), ArrayFromJSON(int64(), "[]"), ArrayFromJSON(utf8(), "[]")})); + + // Do not dictionary-decode when we have only dictionary values + CheckDispatchBest("case_when", + {struct_({field("", boolean())}), dictionary(int64(), utf8()), + dictionary(int64(), utf8())}, + {struct_({field("", boolean())}), dictionary(int64(), utf8()), + dictionary(int64(), utf8())}); + + // Dictionary-decode if we have a mix + CheckDispatchBest( + "case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()}, + {struct_({field("", boolean())}), utf8(), utf8()}); } template From 8abb93f53ee0eac7c0f9089f8f922f02129414d3 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 13:01:37 -0400 Subject: [PATCH 12/21] ARROW-13573: [C++] Fix mistakes --- cpp/src/arrow/array/builder_dict.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 4d0987e9e1f..0637c9722a8 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -302,7 +302,7 @@ class DictionaryBuilderBase : public ArrayBuilder { internal::checked_cast(scalar); const auto& dict = internal::checked_cast::ArrayType&>( *dict_scalar.value.dictionary); - RETURN_NOT_OK(Reserve(n_repeats)); + ARROW_RETURN_NOT_OK(Reserve(n_repeats)); switch (dict_ty.index_type()->id()) { case Type::UINT8: return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); @@ -337,7 +337,7 @@ class DictionaryBuilderBase : public ArrayBuilder { // Visit the indices and insert the unpacked values. 
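// (Annotation, not a line of the diff: as in AppendScalar above, the status
//  macro is switched to the ARROW_-prefixed form used elsewhere in this public
//  header, and the index-visiting lambda below now widens the dictionary index
//  to int64_t before calling IsValid/GetView on the dictionary.)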
const auto& dict_ty = internal::checked_cast(*array.type); const typename TypeTraits::ArrayType dict(array.dictionary); - RETURN_NOT_OK(Reserve(length)); + ARROW_RETURN_NOT_OK(Reserve(length)); switch (dict_ty.index_type()->id()) { case Type::UINT8: return AppendArraySliceImpl(dict, array, offset, length); @@ -461,9 +461,10 @@ class DictionaryBuilderBase : public ArrayBuilder { const c_type* values = array.GetValues(1) + offset; return VisitBitBlocks( array.buffers[0], array.offset + offset, length, - [&](c_type position) { - if (dict.IsValid(values[position])) { - return Append(dict.GetView(values[position])); + [&](const int64_t position) { + const int64_t index = static_cast(values[position]); + if (dict.IsValid(index)) { + return Append(dict.GetView(index)); } return AppendNull(); }, From 45563d156fba59282e62111ee04207d30267fb9f Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 7 Sep 2021 13:08:32 -0400 Subject: [PATCH 13/21] ARROW-13573: [C++] Fix undefined behavior --- cpp/src/arrow/scalar.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 77129035cb1..adfc50182cb 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -599,8 +599,9 @@ Result> DictionaryScalar::GetEncodedValue() const { std::shared_ptr DictionaryScalar::Make(std::shared_ptr index, std::shared_ptr dict) { auto type = dictionary(index->type, dict->type()); + auto is_valid = index->is_valid; return std::make_shared(ValueType{std::move(index), std::move(dict)}, - std::move(type), index->is_valid); + std::move(type), is_valid); } namespace { From 5fc1a1f7585944979be3299b07c89fcf7525e9a2 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 10 Sep 2021 17:50:59 -0400 Subject: [PATCH 14/21] ARROW-13573: [C++] See if turning off unity builds fixes R CI --- ci/scripts/PKGBUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 56d70d83daf..173ff0797d5 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -115,7 +115,6 @@ build() { -DARROW_CXXFLAGS="${CPPFLAGS}" \ -DCMAKE_BUILD_TYPE="release" \ -DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \ - -DCMAKE_UNITY_BUILD=ON \ -DCMAKE_VERBOSE_MAKEFILE=ON make -j3 From f2a0a9e8df1b6941b80415ca2b39aba017008351 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 13 Sep 2021 09:02:17 -0400 Subject: [PATCH 15/21] ARROW-13573: [C++] Try bumping timeout --- .github/workflows/cpp.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 086f45d6fee..0f19f7351c3 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -238,7 +238,7 @@ jobs: name: AMD64 Windows MinGW ${{ matrix.mingw-n-bits }} C++ runs-on: windows-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 45 + timeout-minutes: 60 strategy: fail-fast: false matrix: From a5b6078efe8d820c0faa84239ba5d7c72a90bf9b Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 13 Sep 2021 11:49:38 -0400 Subject: [PATCH 16/21] ARROW-13573: [C++] Should fix MinGW32 --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index ab7071af20c..35bb6248f23 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1726,10 +1726,8 @@ struct CaseWhenFunctor { } static Status 
ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return ExecVarWidthArrayCaseWhen( - ctx, batch, out, - // ReserveData - [&](ArrayBuilder* raw_builder) { return Status::OK(); }); + std::function reserve_data = ReserveNoData; + return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data)); } }; From 29a2f87cb894560ee1206dcd577953099eebf3af Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 14 Sep 2021 14:48:13 -0400 Subject: [PATCH 17/21] ARROW-13573: [C++] Make CMAKE_UNITY_BUILD depend on the rtools version --- ci/scripts/PKGBUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 173ff0797d5..de3f7a6c1e8 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -80,9 +80,11 @@ build() { export LIBS="-L${MINGW_PREFIX}/libs" export ARROW_S3=OFF export ARROW_WITH_RE2=OFF + export CMAKE_UNITY_BUILD=ON else export ARROW_S3=ON export ARROW_WITH_RE2=ON + export CMAKE_UNITY_BUILD=OFF fi MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ @@ -115,6 +117,7 @@ build() { -DARROW_CXXFLAGS="${CPPFLAGS}" \ -DCMAKE_BUILD_TYPE="release" \ -DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \ + -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ -DCMAKE_VERBOSE_MAKEFILE=ON make -j3 From d81773ddf2b6de8e8726b2d9927468143e0dfa7c Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 14 Sep 2021 15:50:05 -0400 Subject: [PATCH 18/21] ARROW-13573: [C++] RTools40 build is very slow without unity build --- .github/workflows/r.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index e160ba8128a..3886eafee94 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -53,7 +53,7 @@ jobs: name: AMD64 Ubuntu ${{ matrix.ubuntu }} R ${{ matrix.r }} runs-on: ubuntu-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 60 + timeout-minutes: 75 strategy: fail-fast: false matrix: From 15a64a211f4ee9dfabac7af8a9e3712d2326cf9c Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 14 Sep 2021 19:13:48 -0400 Subject: [PATCH 19/21] ARROW-13573: [C++] Add clarifying comments --- ci/scripts/PKGBUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index de3f7a6c1e8..a15d1712bc9 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -80,10 +80,12 @@ build() { export LIBS="-L${MINGW_PREFIX}/libs" export ARROW_S3=OFF export ARROW_WITH_RE2=OFF + # Without this, some dataset functionality segfaults export CMAKE_UNITY_BUILD=ON else export ARROW_S3=ON export ARROW_WITH_RE2=ON + # Without this, some compute functionality segfaults export CMAKE_UNITY_BUILD=OFF fi From 26230a4f4a1f3e27b459eb0d9aba3a385241410a Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Sep 2021 11:15:58 -0400 Subject: [PATCH 20/21] ARROW-13573: [C++] Address feedback --- ci/scripts/PKGBUILD | 2 +- cpp/src/arrow/ipc/json_simple_test.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index a15d1712bc9..246b679129a 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -85,7 +85,7 @@ build() { else export ARROW_S3=ON export ARROW_WITH_RE2=ON - # Without this, some compute functionality segfaults + # Without this, some compute functionality segfaults in tests export CMAKE_UNITY_BUILD=OFF fi diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index 273f980c16c..34c300faa95 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ 
b/cpp/src/arrow/ipc/json_simple_test.cc @@ -1395,6 +1395,7 @@ TEST(TestDictScalarFromJSON, Basics) { auto expected_index = ScalarFromJSON(int32(), index); AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary), *scalar, /*verbose=*/true); + ASSERT_OK(scalar->ValidateFull()); } } From ba39d83919909e5f5cee231e02a51b9f6f4decd5 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Sep 2021 14:05:19 -0400 Subject: [PATCH 21/21] ARROW-13573: [C++] Explicitly indicate when we expect dictionary-encoded results in tests --- .../compute/kernels/scalar_if_else_test.cc | 68 ++++++++++++------- cpp/src/arrow/compute/kernels/test_util.cc | 29 +++++--- cpp/src/arrow/compute/kernels/test_util.h | 6 +- 3 files changed, 66 insertions(+), 37 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 7ba0a223556..8793cac7619 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -667,13 +667,19 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries CheckDictionary("case_when", - {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}); + {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}, + /*result_is_encoded=*/false); CheckDictionary("case_when", - {MakeStruct({cond1, cond2}), values1_decoded, values2_dict}); - CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1_dict, values2_dict, - values1_decoded}); - CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values_null, values2_dict, - values1_decoded}); + {MakeStruct({cond1, cond2}), values1_decoded, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); } TYPED_TEST(TestCaseWhenDict, NestedSimple) { @@ -695,30 +701,40 @@ TYPED_TEST(TestCaseWhenDict, NestedSimple) { auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); - CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); - CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, - make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), - values2_backing)}); - CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, - make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), - values2_backing), - values1}); - - CheckDictionary("case_when", { - Datum(MakeStruct({cond1, cond2})), - Datum(std::make_shared( - DictArrayFromJSON(inner_type, "[0, 1]", dict))), - Datum(std::make_shared( - DictArrayFromJSON(inner_type, "[2, 3]", dict))), - }); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1}, + /*result_is_encoded=*/false); CheckDictionary("case_when", - {MakeStruct({Datum(true), Datum(false)}), values1, values2}); + 
{ + Datum(MakeStruct({cond1, cond2})), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[0, 1]", dict))), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[2, 3]", dict))), + }, + /*result_is_encoded=*/false); + CheckDictionary("case_when", - {MakeStruct({Datum(false), Datum(true)}), values1, values2}); - CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2}); + {MakeStruct({Datum(true), Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); CheckDictionary("case_when", - {MakeStruct({Datum(false), Datum(false)}), values1, values2}); + {MakeStruct({Datum(false), Datum(true)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); } TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index fc210ae9346..cedc03698a1 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -171,7 +171,8 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte } } -Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args) { +Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args, + bool result_is_encoded) { EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args)); ValidateOutput(actual); @@ -188,7 +189,10 @@ Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVecto } EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args)); - if (actual.type()->id() == Type::DICTIONARY) { + if (result_is_encoded) { + EXPECT_EQ(Type::DICTIONARY, actual.type()->id()) + << "Result should have been dictionary-encoded"; + // Decode before comparison - we care about equivalent not identical results const auto& to_type = checked_cast(*actual.type()).value_type(); EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type)); @@ -199,8 +203,9 @@ Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVecto return actual; } -void CheckDictionary(const std::string& func_name, const DatumVector& args) { - auto actual = CheckDictionaryNonRecursive(func_name, args); +void CheckDictionary(const std::string& func_name, const DatumVector& args, + bool result_is_encoded) { + auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded); if (actual.is_scalar()) return; ASSERT_TRUE(actual.is_array()); @@ -208,19 +213,23 @@ void CheckDictionary(const std::string& func_name, const DatumVector& args) { // Check all scalars for (int64_t i = 0; i < actual.length(); i++) { - CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i))); + CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)), + result_is_encoded); } // Check slices of the input const auto slice_length = actual.length() / 3; if (slice_length > 0) { - CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length)); - CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length)); - CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length)); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length), + result_is_encoded); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, 
slice_length), + result_is_encoded); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length), + result_is_encoded); } // Check empty slice - CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0)); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded); // Check chunked arrays if (slice_length > 0) { @@ -235,7 +244,7 @@ void CheckDictionary(const std::string& func_name, const DatumVector& args) { chunked_args.push_back(arg); } } - CheckDictionaryNonRecursive(func_name, chunked_args); + CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded); } } diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 89e68950175..25ea577a423 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -78,7 +78,11 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte // Like CheckScalar, but gets the expected result by // dictionary-decoding arguments and calling the function again. -void CheckDictionary(const std::string& func_name, const DatumVector& args); +// +// result_is_encoded controls whether the result is expected to be a +// dictionary or not. +void CheckDictionary(const std::string& func_name, const DatumVector& args, + bool result_is_encoded = true); // Just call the function with the given arguments. void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
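
Taken together, the helpers introduced in this series (DictScalarFromJSON, DictArrayFromJSON, and CheckDictionary with its result_is_encoded flag) are meant to be used along the following lines. The sketch below is a hypothetical test, not part of the patches above: the suite name, the JSON values, and the choice of the case_when kernel are illustrative only, and it assumes the sketch lives alongside the existing TestCaseWhenDict tests in scalar_if_else_test.cc, where MakeStruct and the JSON helpers are already available. The helper signatures themselves are the ones declared in json_simple.h, gtest_util.h, and test_util.h.

TEST(TestDictionaryHelpersSketch, CaseWhen) {
  auto type = dictionary(int32(), utf8());
  auto dict = R"(["alpha", null, "gamma"])";

  // A dictionary-encoded scalar: index 2 into ["alpha", null, "gamma"].
  auto fallback = DictScalarFromJSON(type, "2", dict);

  // One boolean condition plus dictionary-encoded values.
  auto cond = MakeStruct({ArrayFromJSON(boolean(), "[true, false, null]")});
  auto values = DictArrayFromJSON(type, "[0, 2, null]", dict);

  // All-dictionary arguments: the result should stay dictionary-encoded
  // (the default result_is_encoded=true).
  CheckDictionary("case_when", {cond, values, fallback});

  // Mixing in a decoded argument makes the kernel decode its inputs, so the
  // result is expected to be a plain utf8 array.
  CheckDictionary("case_when",
                  {cond, values, ArrayFromJSON(utf8(), R"(["x", "y", "z"])")},
                  /*result_is_encoded=*/false);
}

The first call exercises the path tested by the DispatchBest addition in patch 11 (dictionary-only arguments keep their encoding), while the second mirrors the Mixed test updated in patch 21, where a decoded argument forces a decoded result.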