From ae86c091f2783fda4dfe0f201ad7f92c554e6e44 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 29 Sep 2021 13:12:29 -0400 Subject: [PATCH 1/8] ARROW-14167: [C++] Directly support dictionaries in coalesce --- .../arrow/compute/kernels/scalar_if_else.cc | 48 +++++- .../compute/kernels/scalar_if_else_test.cc | 154 +++++++++++++++++- 2 files changed, 192 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 6195d1381a0..dda42171de2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1439,13 +1439,30 @@ struct CaseWhenFunction : ScalarFunction { } } - if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + // TODO(ARROW-14105): also apply casts to dictionary indices/values + if (is_dictionary((*values)[1].type->id()) && + std::all_of(values->begin() + 2, values->end(), [&](const ValueDescr& descr) { + return descr.type->Equals(*(*values)[1].type); + })) { + auto kernel = DispatchExactImpl(this, *values); + DCHECK(kernel); + return kernel; + } EnsureDictionaryDecoded(values); - if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { - for (auto it = values->begin() + 1; it != values->end(); it++) { - it->type = type; - } + ValueDescr* first_arg = &(*values)[1]; + const size_t num_args = values->size() - 1; + if (auto type = CommonNumeric(first_arg, num_args)) { + ReplaceTypes(type, first_arg, num_args); + } + if (auto type = CommonBinary(first_arg, num_args)) { + ReplaceTypes(type, first_arg, num_args); + } + if (auto type = CommonTemporal(first_arg, num_args)) { + ReplaceTypes(type, first_arg, num_args); + } + if (HasDecimal(*values)) { + RETURN_NOT_OK(CastDecimalArgs(first_arg, num_args)); } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); @@ -1934,9 +1951,20 @@ struct CoalesceFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); using arrow::compute::detail::DispatchExactImpl; + + // TODO(ARROW-14105): also apply casts to dictionary indices/values + if (is_dictionary((*values)[0].type->id()) && + std::all_of(values->begin() + 1, values->end(), [&](const ValueDescr& descr) { + return descr.type->Equals(*(*values)[0].type); + })) { + auto kernel = DispatchExactImpl(this, *values); + DCHECK(kernel); + return kernel; + } + // Do not DispatchExact here since we want to rescale decimals if necessary EnsureDictionaryDecoded(values); - if (auto type = CommonNumeric(*values)) { + if (auto type = CommonNumeric(values->data(), values->size())) { ReplaceTypes(type, values); } if (auto type = CommonBinary(values->data(), values->size())) { @@ -2244,7 +2272,7 @@ static Status ExecVarWidthCoalesceImpl(KernelContext* ctx, const ExecBatch& batc } ArrayData* output = out->mutable_array(); std::unique_ptr raw_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder)); RETURN_NOT_OK(raw_builder->Reserve(batch.length)); RETURN_NOT_OK(reserve_data(raw_builder.get())); @@ -2388,7 +2416,8 @@ struct CoalesceFunctor> { template struct CoalesceFunctor< - Type, enable_if_t::value && !is_union_type::value>> { + Type, enable_if_t<(is_nested_type::value || is_dictionary_type::value) && + !is_union_type::value>> { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[0], batch.values.size())); for (const auto& datum : batch.values) { @@ -2422,7 +2451,7 @@ struct CoalesceFunctor> { static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { ArrayData* output = out->mutable_array(); std::unique_ptr raw_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder)); RETURN_NOT_OK(raw_builder->Reserve(batch.length)); const UnionType& type = checked_cast(*out->type()); @@ -2858,6 +2887,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor::Exec); AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor::Exec); AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor::Exec); + AddCoalesceKernel(func, Type::DICTIONARY, CoalesceFunctor::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 7bcbb814ada..8493ccd2e62 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1094,7 +1094,8 @@ TYPED_TEST(TestCaseWhenDict, Simple) { } TYPED_TEST(TestCaseWhenDict, Mixed) { - auto type = dictionary(default_type_instance(), utf8()); + auto index_type = default_type_instance(); + auto type = dictionary(index_type, utf8()); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); auto dict = R"(["a", null, "bc", "def"])"; @@ -1119,6 +1120,17 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { "case_when", {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded}, /*result_is_encoded=*/false); + + // If we have mismatched dictionary types, we decode (for now) + auto values3_dict = + DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict); + auto values4_dict = DictArrayFromJSON( + dictionary(index_type->id() == Type::UINT8 ? int8() : uint8(), utf8()), + "[2, 1, null, 0]", dict); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1_dict, values3_dict}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1_dict, values4_dict}, + /*result_is_encoded=*/false); } TYPED_TEST(TestCaseWhenDict, NestedSimple) { @@ -2088,6 +2100,17 @@ TEST(TestCaseWhen, UnionBoolString) { TEST(TestCaseWhen, DispatchBest) { CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()}, {struct_({field("", boolean())}), int64(), int64()}); + CheckDispatchBest("case_when", + {struct_({field("", boolean())}), binary(), large_utf8()}, + {struct_({field("", boolean())}), large_binary(), large_binary()}); + CheckDispatchBest( + "case_when", + {struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()}, + {struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND)}); + CheckDispatchBest( + "case_when", {struct_({field("", boolean())}), decimal128(38, 0), decimal128(1, 1)}, + {struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)}); ASSERT_RAISES(Invalid, CallFunction("case_when", {})); // Too many/too few conditions @@ -2360,6 +2383,132 @@ TYPED_TEST(TestCoalesceList, Errors) { })); } +template +class TestCoalesceDict : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCoalesceDict, IntegralArrowTypes); + +TYPED_TEST(TestCoalesceDict, Simple) { + for (const auto& dict : + {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, + JsonDict{int64(), "[1, null, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { + auto type = dictionary(default_type_instance(), dict.type); + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, null]", dict.value); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, null]", dict.value); + auto scalar = DictScalarFromJSON(type, "2", dict.value); + + // Easy case: all arguments have the same dictionary + CheckDictionary("coalesce", {values1, values2}); + CheckDictionary("coalesce", {values1, values2, values1}); + CheckDictionary("coalesce", {values_null, values1}); + CheckDictionary("coalesce", {values1, values_null}); + CheckDictionary("coalesce", {values1, scalar}); + CheckDictionary("coalesce", {values_null, scalar}); + CheckDictionary("coalesce", {scalar, values1}); + } +} + +TYPED_TEST(TestCoalesceDict, Mixed) { + auto index_type = default_type_instance(); + auto type = dictionary(index_type, utf8()); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); + auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); + auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); + auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); + auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])"); + auto scalar = ScalarFromJSON(utf8(), R"("bc")"); + + // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries + CheckDictionary("coalesce", {values1_dict, values2_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values1_decoded, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values1_dict, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values_null, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values_null, scalar}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {scalar, values_null}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values1_dict, scalar}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {scalar, values2_dict}, /*result_is_encoded=*/false); + + // If we have mismatched dictionary types, we decode (for now) + auto values3_dict = + DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict); + auto values4_dict = DictArrayFromJSON( + dictionary(index_type->id() == Type::UINT8 ? int8() : uint8(), utf8()), + "[2, 1, null, 0]", dict); + CheckDictionary("coalesce", {values1_dict, values3_dict}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values1_dict, values4_dict}, /*result_is_encoded=*/false); +} + +TYPED_TEST(TestCoalesceDict, NestedSimple) { + auto index_type = default_type_instance(); + auto inner_type = dictionary(index_type, utf8()); + auto type = list(inner_type); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); + auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); + auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); + auto values1 = + MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = + MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing); + auto scalar = + Datum(std::make_shared(DictArrayFromJSON(inner_type, "[0, 1]", dict))); + + CheckDictionary("coalesce", {values1, values2}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values1, scalar}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {scalar, values2}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values_null, values2}, /*result_is_encoded=*/false); + CheckDictionary("coalesce", {values1, values_null}, /*result_is_encoded=*/false); +} + +TYPED_TEST(TestCoalesceDict, DifferentDictionaries) { + auto type = dictionary(default_type_instance(), utf8()); + auto dict1 = R"(["a", "", "bc", "def"])"; + auto dict2 = R"(["bc", "foo", "", "a"])"; + auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); + auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); + auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1); + auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2); + auto scalar1 = DictScalarFromJSON(type, "0", dict1); + auto scalar2 = DictScalarFromJSON(type, "0", dict2); + + CheckDictionary("coalesce", {values1, values2}); + CheckDictionary("coalesce", {values1, scalar2}); + CheckDictionary("coalesce", {scalar1, values2}); + CheckDictionary("coalesce", {values1, scalar2}); + CheckDictionary("coalesce", {values1_null, values2}); + CheckDictionary("coalesce", {values1, values2_null}); + + // Test dictionaries with nulls (where decoding before/after calling coalesce changes + // the results) + dict1 = R"(["a", null, "bc", "def"])"; + dict2 = R"(["bc", "foo", null, "a"])"; + values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1); + values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2); + scalar1 = DictScalarFromJSON(type, "0", dict1); + + // Note this is sensitive to the implementation. Nulls are emitted here + // because a non-null index mapped to a null dictionary value and was emitted + // as a null (instead of encoding null in the dictionary) + CheckScalarNonRecursive( + "coalesce", {values1, values2}, + DictArrayFromJSON(type, "[null, 0, 1, null]", R"(["a", "def"])")); + CheckScalarNonRecursive("coalesce", {values1, scalar1}, + DictArrayFromJSON(type, "[0, 0, 1, null]", R"(["a", "def"])")); + // The dictionary gets preserved since a leading non-null scalar just gets + // broadcasted and returned without going through the rest of the kernel + // implementation + CheckScalarNonRecursive("coalesce", {scalar1, values1}, + DictArrayFromJSON(type, "[0, 0, 0, 0]", dict1)); +} + TEST(TestCoalesce, Null) { auto type = null(); auto scalar_null = ScalarFromJSON(type, "null"); @@ -2716,6 +2865,9 @@ TEST(TestCoalesce, DispatchBest) { sparse_union({field("a", boolean())}), dense_union({field("a", boolean())}), }); + CheckDispatchBest("coalesce", + {dictionary(int8(), binary()), dictionary(int16(), large_utf8())}, + {large_binary(), large_binary()}); } template From 781e7a2b1554c527909d2f19d12904a993f85513 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 15:30:16 -0500 Subject: [PATCH 2/8] Add dictionary case to R coalesce() tests --- r/tests/testthat/test-dplyr-funcs-conditional.R | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 4f270079580..c83dd21e79e 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -300,6 +300,22 @@ test_that("coalesce()", { collect(), df ) + + # factor + df_fct <- df %>% + mutate(across(everything(), ~ factor(.x))) + compare_dplyr_binding( + .input %>% + mutate( + cw = coalesce(w), + cz = coalesce(z), + cwx = coalesce(w, x), + cwxy = coalesce(w, x, y), + cwxyz = coalesce(w, x, y, z) + ) %>% + collect(), + df_fct + ) # integer df <- tibble( From 616f651bb28fdbf4f9a3d07851f8a22f9adbee33 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 15:41:26 -0500 Subject: [PATCH 3/8] Remove trailing whitespace --- r/tests/testthat/test-dplyr-funcs-conditional.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index c83dd21e79e..09e687d2986 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -300,7 +300,7 @@ test_that("coalesce()", { collect(), df ) - + # factor df_fct <- df %>% mutate(across(everything(), ~ factor(.x))) From f6d5db68c468eff0b334dc7bcc24566bb03c4a9f Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 15:56:28 -0500 Subject: [PATCH 4/8] Fix failing R test --- r/tests/testthat/test-dplyr-funcs-conditional.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 09e687d2986..234f82293f8 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -303,7 +303,7 @@ test_that("coalesce()", { # factor df_fct <- df %>% - mutate(across(everything(), ~ factor(.x))) + transmute(across(everything(), ~ factor(.x, levels = c("a", "b", "c")))) compare_dplyr_binding( .input %>% mutate( From d27e65e48898099e6af413197d1dfcaa062c13c4 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 16:23:27 -0500 Subject: [PATCH 5/8] Try again to fix failing R test --- r/tests/testthat/test-dplyr-funcs-conditional.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 234f82293f8..5a1771cce67 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -313,7 +313,10 @@ test_that("coalesce()", { cwxy = coalesce(w, x, y), cwxyz = coalesce(w, x, y, z) ) %>% - collect(), + collect() %>% + # Arrow case_when() kernel does not preserve factor levels + # so reset the levels of all the factor columns + transmute(across(where(is.factor), ~ factor(.x, levels = c("a", "b", "c")))), df_fct ) From bea6e693928a2b121da681858c7c7f43b64105a3 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 19:05:30 -0500 Subject: [PATCH 6/8] Remove R warning --- r/R/dplyr-functions.R | 6 ------ 1 file changed, 6 deletions(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 0ef30b22255..170ab4b8ffb 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -77,12 +77,6 @@ nse_funcs$coalesce <- function(...) { arg <- Expression$scalar(arg) } - # coalesce doesn't yet support factors/dictionaries - # TODO: remove this after ARROW-14167 is merged - if (nse_funcs$is.factor(arg)) { - warning("Dictionaries (in R: factors) are currently converted to strings (characters) in coalesce", call. = FALSE) - } - if (last_arg && arg$type_id() %in% TYPES_WITH_NAN) { # store the NA_real_ in the same type as arg to avoid avoid casting # smaller float types to larger float types From ad587ebe8540b0015af141b2d86dd03bcbf02a1a Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 19:10:26 -0500 Subject: [PATCH 7/8] Remove outdated test --- r/tests/testthat/test-dplyr-funcs-conditional.R | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 5a1771cce67..1f8ac693305 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -402,23 +402,6 @@ test_that("coalesce()", { df ) - # factors - # TODO: remove the mutate + warning after ARROW-14167 is merged and Arrow - # supports factors in coalesce - df <- tibble( - x = factor("a", levels = c("a", "z")), - y = factor("b", levels = c("a", "b", "c")) - ) - compare_dplyr_binding( - .input %>% - mutate(c = coalesce(x, y)) %>% - collect() %>% - # This is a no-op on the Arrow side, but necessary to make the results equal - mutate(c = as.character(c)), - df, - warning = "Dictionaries .* are currently converted to strings .* in coalesce" - ) - # no arguments expect_error( nse_funcs$coalesce(), From 67bbf7964e92e881f6b996120f1e58287ee264a3 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 9 Nov 2021 19:33:05 -0500 Subject: [PATCH 8/8] Fix comment and include Jira ID --- r/tests/testthat/test-dplyr-funcs-conditional.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 1f8ac693305..4e4346a7592 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -314,8 +314,9 @@ test_that("coalesce()", { cwxyz = coalesce(w, x, y, z) ) %>% collect() %>% - # Arrow case_when() kernel does not preserve factor levels - # so reset the levels of all the factor columns + # Arrow coalesce() kernel does not preserve unused factor levels, + # so reset the levels of all the factor columns to make the test pass + # (ARROW-14649) transmute(across(where(is.factor), ~ factor(.x, levels = c("a", "b", "c")))), df_fct )