diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 0a567e385e7..c099ec660b8 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -38,6 +38,7 @@
 #include "arrow/compute/kernels/aggregate_var_std_internal.h"
 #include "arrow/compute/kernels/common.h"
 #include "arrow/compute/kernels/util_internal.h"
+#include "arrow/record_batch.h"
 #include "arrow/util/bit_run_reader.h"
 #include "arrow/util/bitmap_ops.h"
 #include "arrow/util/bitmap_writer.h"
@@ -1970,6 +1971,7 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {
   Status Init(ExecContext* ctx, const FunctionOptions* options) override {
     ctx_ = ctx;
     pool_ = ctx->memory_pool();
+    options_ = checked_cast<const CountOptions&>(*options);
     return Status::OK();
   }
 
@@ -2013,8 +2015,21 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {
     ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
     auto* g = uniques[1].array()->GetValues<uint32_t>(1);
-    for (int64_t i = 0; i < uniques.length; i++) {
-      counts[g[i]]++;
+    const auto& items = *uniques[0].array();
+    const auto* valid = items.GetValues<uint8_t>(0, 0);
+    if (options_.mode == CountOptions::ALL ||
+        (options_.mode == CountOptions::ONLY_VALID && !valid)) {
+      for (int64_t i = 0; i < uniques.length; i++) {
+        counts[g[i]]++;
+      }
+    } else if (options_.mode == CountOptions::ONLY_VALID) {
+      for (int64_t i = 0; i < uniques.length; i++) {
+        counts[g[i]] += BitUtil::GetBit(valid, items.offset + i);
+      }
+    } else if (valid) {  // ONLY_NULL
+      for (int64_t i = 0; i < uniques.length; i++) {
+        counts[g[i]] += !BitUtil::GetBit(valid, items.offset + i);
+      }
     }
 
     return ArrayData::Make(int64(), num_groups_, {nullptr, std::move(values)},
@@ -2026,6 +2041,7 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {
   ExecContext* ctx_;
   MemoryPool* pool_;
   int64_t num_groups_;
+  CountOptions options_;
   std::unique_ptr<Grouper> grouper_;
   std::shared_ptr<DataType> out_type_;
 };
@@ -2036,7 +2052,56 @@ struct GroupedDistinctImpl : public GroupedCountDistinctImpl {
     ARROW_ASSIGN_OR_RAISE(auto groupings, grouper_->MakeGroupings(
                                               *uniques[1].array_as<UInt32Array>(),
                                               static_cast<uint32_t>(num_groups_), ctx_));
-    return grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_);
+    ARROW_ASSIGN_OR_RAISE(
+        auto list, grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_));
+    auto values = list->values();
+    DCHECK_EQ(values->offset(), 0);
+    int32_t* offsets = reinterpret_cast<int32_t*>(list->value_offsets()->mutable_data());
+    if (options_.mode == CountOptions::ALL ||
+        (options_.mode == CountOptions::ONLY_VALID && values->null_count() == 0)) {
+      return list;
+    } else if (options_.mode == CountOptions::ONLY_VALID) {
+      int32_t prev_offset = offsets[0];
+      for (int64_t i = 0; i < list->length(); i++) {
+        const int32_t slot_length = offsets[i + 1] - prev_offset;
+        const int64_t null_count =
+            slot_length - arrow::internal::CountSetBits(values->null_bitmap()->data(),
+                                                        prev_offset, slot_length);
+        DCHECK_LE(null_count, 1);
+        const int32_t offset = null_count > 0 ? slot_length - 1 : slot_length;
+        prev_offset = offsets[i + 1];
+        offsets[i + 1] = offsets[i] + offset;
+      }
+      auto filter =
+          std::make_shared<BooleanArray>(values->length(), values->null_bitmap());
+      ARROW_ASSIGN_OR_RAISE(
+          auto new_values,
+          Filter(std::move(values), filter, FilterOptions(FilterOptions::DROP), ctx_));
+      return std::make_shared<ListArray>(list->type(), list->length(),
+                                         list->value_offsets(), new_values.make_array());
+    }
+    // ONLY_NULL
+    if (values->null_count() == 0) {
+      std::fill(offsets + 1, offsets + list->length() + 1, offsets[0]);
+    } else {
+      int32_t prev_offset = offsets[0];
+      for (int64_t i = 0; i < list->length(); i++) {
+        const int32_t slot_length = offsets[i + 1] - prev_offset;
+        const int64_t null_count =
+            slot_length - arrow::internal::CountSetBits(values->null_bitmap()->data(),
+                                                        prev_offset, slot_length);
+        const int32_t offset = null_count > 0 ? 1 : 0;
+        prev_offset = offsets[i + 1];
+        offsets[i + 1] = offsets[i] + offset;
+      }
+    }
+    ARROW_ASSIGN_OR_RAISE(
+        auto new_values,
+        MakeArrayOfNull(out_type_,
+                        list->length() > 0 ? offsets[list->length()] - offsets[0] : 0,
+                        pool_));
+    return std::make_shared<ListArray>(list->type(), list->length(),
+                                       list->value_offsets(), std::move(new_values));
   }
 
   std::shared_ptr<DataType> out_type() const override { return list(out_type_); }
@@ -2383,22 +2448,26 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",
 
 const FunctionDoc hash_count_distinct_doc{
     "Count the distinct values in each group",
-    ("Nulls are counted. NaNs and signed zeroes are not normalized."),
-    {"array", "group_id_array"}};
+    ("Whether nulls/values are counted is controlled by CountOptions.\n"
+     "NaNs and signed zeroes are not normalized."),
+    {"array", "group_id_array"},
+    "CountOptions"};
 
 const FunctionDoc hash_distinct_doc{
     "Keep the distinct values in each group",
-    ("Nulls are kept. NaNs and signed zeroes are not normalized."),
-    {"array", "group_id_array"}};
+    ("Whether nulls/values are kept is controlled by CountOptions.\n"
+     "NaNs and signed zeroes are not normalized."),
+    {"array", "group_id_array"},
+    "CountOptions"};
 }  // namespace
 
 void RegisterHashAggregateBasic(FunctionRegistry* registry) {
+  static auto default_count_options = CountOptions::Defaults();
   static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
   static auto default_tdigest_options = TDigestOptions::Defaults();
   static auto default_variance_options = VarianceOptions::Defaults();
 
   {
-    static auto default_count_options = CountOptions::Defaults();
     auto func = std::make_shared<HashAggregateFunction>(
         "hash_count", Arity::Binary(), &hash_count_doc, &default_count_options);
@@ -2516,15 +2585,16 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
   {
     auto func = std::make_shared<HashAggregateFunction>(
-        "hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc);
+        "hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc,
+        &default_count_options);
     DCHECK_OK(func->AddKernel(
         MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedCountDistinctImpl>)));
     DCHECK_OK(registry->AddFunction(std::move(func)));
   }
 
   {
-    auto func = std::make_shared<HashAggregateFunction>("hash_distinct", Arity::Binary(),
-                                                        &hash_distinct_doc);
+    auto func = std::make_shared<HashAggregateFunction>(
+        "hash_distinct", Arity::Binary(), &hash_distinct_doc, &default_count_options);
     DCHECK_OK(func->AddKernel(
         MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedDistinctImpl>)));
     DCHECK_OK(registry->AddFunction(std::move(func)));
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index df2222a4eef..4d6064fa62d 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1240,6 +1240,9 @@ TEST(GroupBy, AnyAndAll) {
 }
 
 TEST(GroupBy, CountDistinct) {
+  CountOptions all(CountOptions::ALL);
+  CountOptions only_valid(CountOptions::ONLY_VALID);
+  CountOptions only_null(CountOptions::ONLY_NULL);
   for (bool use_threads : {true, false}) {
     SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
 
@@ -1250,7 +1253,12 @@ TEST(GroupBy, CountDistinct) {
 ])",
                          R"([
     [0, 2],
+    [null, 3],
     [null, 3]
+])",
+                         R"([
+    [null, 4],
+    [null, 4]
 ])",
                          R"([
     [4, null],
@@ -1273,26 +1281,33 @@ TEST(GroupBy, CountDistinct) {
                          internal::GroupBy(
                              {
                                  table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
                              },
                              {
                                  table->GetColumnByName("key"),
                              },
                              {
-                                 {"hash_count_distinct", nullptr},
+                                 {"hash_count_distinct", &all},
+                                 {"hash_count_distinct", &only_valid},
+                                 {"hash_count_distinct", &only_null},
                              },
                              use_threads));
     SortBy({"key_0"}, &aggregated_and_grouped);
     ValidateOutput(aggregated_and_grouped);
 
     AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_count_distinct", int64()),
+                                        field("hash_count_distinct", int64()),
                                         field("hash_count_distinct", int64()),
                                         field("key_0", int64()),
                                     }),
                                     R"([
-    [1, 1],
-    [2, 2],
-    [3, 3],
-    [4, null]
+    [1, 1, 0, 1],
+    [2, 2, 0, 2],
+    [3, 2, 1, 3],
+    [1, 0, 1, 4],
+    [4, 4, 0, null]
 ])"),
                       aggregated_and_grouped,
                       /*verbose=*/true);
@@ -1304,7 +1319,12 @@ TEST(GroupBy, CountDistinct) {
 ])",
                          R"([
     ["bar", 2],
+    [null, 3],
     [null, 3]
+])",
+                         R"([
+    [null, 4],
+    [null, 4]
 ])",
                          R"([
     ["baz", null],
@@ -1327,26 +1347,76 @@ TEST(GroupBy, CountDistinct) {
                          internal::GroupBy(
                              {
                                  table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
                             },
                              {
                                  table->GetColumnByName("key"),
                              },
                              {
-                                 {"hash_count_distinct", nullptr},
+                                 {"hash_count_distinct", &all},
+                                 {"hash_count_distinct", &only_valid},
+                                 {"hash_count_distinct", &only_null},
                              },
                              use_threads));
     ValidateOutput(aggregated_and_grouped);
     SortBy({"key_0"}, &aggregated_and_grouped);
 
     AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_count_distinct", int64()),
+                                        field("hash_count_distinct", int64()),
                                         field("hash_count_distinct", int64()),
                                         field("key_0", int64()),
                                     }),
                                     R"([
-    [1, 1],
-    [2, 2],
-    [3, 3],
-    [4, null]
+    [1, 1, 0, 1],
+    [2, 2, 0, 2],
+    [3, 2, 1, 3],
+    [1, 0, 1, 4],
+    [4, 4, 0, null]
+])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+
+    table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {
+                          R"([
+    ["foo", 1],
+    ["foo", 1],
+    ["bar", 2],
+    ["bar", 2],
+    ["spam", 2]
+])",
+                      });
+
+    ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_count_distinct", &all},
+                                 {"hash_count_distinct", &only_valid},
+                                 {"hash_count_distinct", &only_null},
+                             },
+                             use_threads));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_count_distinct", int64()),
+                                        field("hash_count_distinct", int64()),
+                                        field("hash_count_distinct", int64()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+    [1, 1, 0, 1],
+    [2, 2, 0, 2]
 ])"),
                       aggregated_and_grouped,
                       /*verbose=*/true);
@@ -1354,6 +1424,9 @@ TEST(GroupBy, CountDistinct) {
 }
 
 TEST(GroupBy, Distinct) {
+  CountOptions all(CountOptions::ALL);
+  CountOptions only_valid(CountOptions::ONLY_VALID);
+  CountOptions only_null(CountOptions::ONLY_NULL);
   for (bool use_threads : {true, false}) {
     SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
 
@@ -1364,7 +1437,12 @@ TEST(GroupBy, Distinct) {
 ])",
                          R"([
     ["bar", 2],
+    [null, 3],
     [null, 3]
+])",
+                         R"([
+    [null, 4],
+    [null, 4]
 ])",
                          R"([
     ["baz", null],
@@ -1387,33 +1465,104 @@ TEST(GroupBy, Distinct) {
                          internal::GroupBy(
                              {
                                  table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
                              },
                              {
                                  table->GetColumnByName("key"),
                              },
                              {
-                                 {"hash_distinct", nullptr},
+                                 {"hash_distinct", &all},
+                                 {"hash_distinct", &only_valid},
+                                 {"hash_distinct", &only_null},
                              },
                              use_threads));
     ValidateOutput(aggregated_and_grouped);
     SortBy({"key_0"}, &aggregated_and_grouped);
 
     // Order of sub-arrays is not stable
-    auto struct_arr = aggregated_and_grouped.array_as<StructArray>();
-    auto distinct_arr = checked_pointer_cast<ListArray>(struct_arr->field(0));
     auto sort = [](const Array& arr) -> std::shared_ptr<Array> {
       EXPECT_OK_AND_ASSIGN(auto indices, SortIndices(arr));
       EXPECT_OK_AND_ASSIGN(auto sorted, Take(arr, indices));
       return sorted.make_array();
     };
-    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"),
-                      sort(*distinct_arr->value_slice(0)), /*verbose=*/true);
+
+    auto struct_arr = aggregated_and_grouped.array_as<StructArray>();
+
+    auto all_arr = checked_pointer_cast<ListArray>(struct_arr->field(0));
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"), sort(*all_arr->value_slice(0)),
+                      /*verbose=*/true);
     AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["bar", "spam"])"),
-                      sort(*distinct_arr->value_slice(1)), /*verbose=*/true);
+                      sort(*all_arr->value_slice(1)), /*verbose=*/true);
     AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo", "ham", null])"),
-                      sort(*distinct_arr->value_slice(2)), /*verbose=*/true);
+                      sort(*all_arr->value_slice(2)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([null])"), sort(*all_arr->value_slice(3)),
+                      /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["a", "b", "baz", "eggs"])"),
+                      sort(*all_arr->value_slice(4)), /*verbose=*/true);
+
+    auto valid_arr = checked_pointer_cast<ListArray>(struct_arr->field(1));
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"),
+                      sort(*valid_arr->value_slice(0)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["bar", "spam"])"),
+                      sort(*valid_arr->value_slice(1)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo", "ham"])"),
+                      sort(*valid_arr->value_slice(2)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*valid_arr->value_slice(3)),
+                      /*verbose=*/true);
     AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["a", "b", "baz", "eggs"])"),
-                      sort(*distinct_arr->value_slice(3)), /*verbose=*/true);
+                      sort(*valid_arr->value_slice(4)), /*verbose=*/true);
+
+    auto null_arr = checked_pointer_cast<ListArray>(struct_arr->field(2));
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*null_arr->value_slice(0)),
+                      /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*null_arr->value_slice(1)),
+                      /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([null])"), sort(*null_arr->value_slice(2)),
+                      /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([null])"), sort(*null_arr->value_slice(3)),
+                      /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"([])"), sort(*null_arr->value_slice(4)),
+                      /*verbose=*/true);
+
+    table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {
+                          R"([
+    ["foo", 1],
+    ["foo", 1],
+    ["bar", 2],
+    ["bar", 2]
+])",
+                      });
+    ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_distinct", &all},
+                                 {"hash_distinct", &only_valid},
+                                 {"hash_distinct", &only_null},
+                             },
+                             use_threads));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(
+        ArrayFromJSON(struct_({
+                          field("hash_distinct", list(utf8())),
+                          field("hash_distinct", list(utf8())),
+                          field("hash_distinct", list(utf8())),
+                          field("key_0", int64()),
+                      }),
+                      R"([[["foo"], ["foo"], [], 1], [["bar"], ["bar"], [], 2]])"),
+        aggregated_and_grouped,
+        /*verbose=*/true);
   }
 }
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 2b39e3ca33a..465500e8dae 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -288,36 +288,43 @@ The supported aggregation functions are as follows. All
 function names are prefixed with ``hash_``, which differentiates them from their
 scalar equivalents above and reflects how they are implemented internally.
 
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| Function name | Arity | Input types | Output type     | Options class                    | Notes |
-+===============+=======+=============+=================+==================================+=======+
-| hash_all      | Unary | Boolean     | Boolean         | :struct:`ScalarAggregateOptions` | \(1)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_any      | Unary | Boolean     | Boolean         | :struct:`ScalarAggregateOptions` | \(1)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_count    | Unary | Any         | Int64           | :struct:`CountOptions`           | \(2)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_mean     | Unary | Numeric     | Decimal/Float64 | :struct:`ScalarAggregateOptions` |       |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_min_max  | Unary | Numeric     | Struct          | :struct:`ScalarAggregateOptions` | \(3)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_product  | Unary | Numeric     | Numeric         | :struct:`ScalarAggregateOptions` | \(4)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_stddev   | Unary | Numeric     | Float64         | :struct:`VarianceOptions`        |       |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_sum      | Unary | Numeric     | Numeric         | :struct:`ScalarAggregateOptions` | \(4)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_tdigest  | Unary | Numeric     | Float64         | :struct:`TDigestOptions`         | \(5)  |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
-| hash_variance | Unary | Numeric     | Float64         | :struct:`VarianceOptions`        |       |
-+---------------+-------+-------------+-----------------+----------------------------------+-------+
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| Function name       | Arity | Input types | Output type     | Options class                    | Notes |
++=====================+=======+=============+=================+==================================+=======+
+| hash_all            | Unary | Boolean     | Boolean         | :struct:`ScalarAggregateOptions` | \(1)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_any            | Unary | Boolean     | Boolean         | :struct:`ScalarAggregateOptions` | \(1)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_count          | Unary | Any         | Int64           | :struct:`CountOptions`           | \(2)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_count_distinct | Unary | Any         | Int64           | :struct:`CountOptions`           | \(2)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_distinct       | Unary | Any         | Input type      | :struct:`CountOptions`           | \(2)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_mean           | Unary | Numeric     | Decimal/Float64 | :struct:`ScalarAggregateOptions` |       |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_min_max        | Unary | Numeric     | Struct          | :struct:`ScalarAggregateOptions` | \(3)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_product        | Unary | Numeric     | Numeric         | :struct:`ScalarAggregateOptions` | \(4)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_stddev         | Unary | Numeric     | Float64         | :struct:`VarianceOptions`        |       |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_sum            | Unary | Numeric     | Numeric         | :struct:`ScalarAggregateOptions` | \(4)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_tdigest        | Unary | Numeric     | Float64         | :struct:`TDigestOptions`         | \(5)  |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_variance       | Unary | Numeric     | Float64         | :struct:`VarianceOptions`        |       |
++---------------------+-------+-------------+-----------------+----------------------------------+-------+
 
 * \(1) If null values are taken into account, by setting the
   :member:`ScalarAggregateOptions::skip_nulls` to false, then `Kleene logic`_
   logic is applied. The min_count option is not respected.
-* \(2) CountMode controls whether only non-null values are counted (the
-  default), only null values are counted, or all values are counted.
+* \(2) CountMode controls whether only non-null values are counted
+  (the default), only null values are counted, or all values are
+  counted. For hash_distinct, it instead controls whether null values
+  are emitted. This never affects the grouping keys, only group values
+  (i.e. you may get a group where the key is null).
 * \(3) Output is a ``{"min": input type, "max": input type}`` Struct scalar.