diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 20dcd8ef331..0a567e385e7 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -1962,6 +1962,97 @@ struct GroupedAllImpl : public GroupedBooleanAggregator {
                                  num_groups, /*out_offset=*/0, no_nulls);
   }
 };
+
+// ----------------------------------------------------------------------
+// CountDistinct/Distinct implementation
+
+struct GroupedCountDistinctImpl : public GroupedAggregator {
+  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+    ctx_ = ctx;
+    pool_ = ctx->memory_pool();
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    num_groups_ = new_num_groups;
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override {
+    return grouper_->Consume(batch).status();
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedCountDistinctImpl*>(&raw_other);
+
+    // Get (value, group_id) pairs, then translate the group IDs and consume them
+    // ourselves
+    ARROW_ASSIGN_OR_RAISE(auto uniques, other->grouper_->GetUniques());
+    ARROW_ASSIGN_OR_RAISE(auto remapped_g,
+                          AllocateBuffer(uniques.length * sizeof(uint32_t), pool_));
+
+    const auto* g_mapping = group_id_mapping.GetValues<uint32_t>(1);
+    const auto* other_g = uniques[1].array()->GetValues<uint32_t>(1);
+    auto* g = reinterpret_cast<uint32_t*>(remapped_g->mutable_data());
+
+    for (int64_t i = 0; i < uniques.length; i++) {
+      g[i] = g_mapping[other_g[i]];
+    }
+    uniques.values[1] =
+        ArrayData::Make(uint32(), uniques.length, {nullptr, std::move(remapped_g)});
+
+    return Consume(std::move(uniques));
+  }
+
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+                          AllocateBuffer(num_groups_ * sizeof(int64_t), pool_));
+    int64_t* counts = reinterpret_cast<int64_t*>(values->mutable_data());
+    std::fill(counts, counts + num_groups_, 0);
+
+    ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
+    auto* g = uniques[1].array()->GetValues<uint32_t>(1);
+    for (int64_t i = 0; i < uniques.length; i++) {
+      counts[g[i]]++;
+    }
+
+    return ArrayData::Make(int64(), num_groups_, {nullptr, std::move(values)},
+                           /*null_count=*/0);
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+  ExecContext* ctx_;
+  MemoryPool* pool_;
+  int64_t num_groups_;
+  std::unique_ptr<Grouper> grouper_;
+  std::shared_ptr<DataType> out_type_;
+};
+
+struct GroupedDistinctImpl : public GroupedCountDistinctImpl {
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
+    ARROW_ASSIGN_OR_RAISE(auto groupings, grouper_->MakeGroupings(
+                                              *uniques[1].array_as<UInt32Array>(),
+                                              static_cast<uint32_t>(num_groups_), ctx_));
+    return grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_);
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return list(out_type_); }
+};
+
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> GroupedDistinctInit(KernelContext* ctx,
+                                                         const KernelInitArgs& args) {
+  ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit<Impl>(ctx, args));
+  auto instance = static_cast<GroupedCountDistinctImpl*>(impl.get());
+  instance->out_type_ = args.inputs[0].type;
+  ARROW_ASSIGN_OR_RAISE(instance->grouper_,
+                        Grouper::Make(args.inputs, ctx->exec_context()));
+  return std::move(impl);
+}
+
 }  // namespace
 
 Result<std::vector<const HashAggregateKernel*>> GetKernels(
@@ -2289,6 +2380,16 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",
                                ("Null values are ignored."),
                                {"array", "group_id_array"},
                                "ScalarAggregateOptions"};
+
+const FunctionDoc hash_count_distinct_doc{
+    "Count the distinct values in each group",
+    ("Nulls are counted. NaNs and signed zeroes are not normalized."),
+    {"array", "group_id_array"}};
+
+const FunctionDoc hash_distinct_doc{
+    "Keep the distinct values in each group",
+    ("Nulls are kept. NaNs and signed zeroes are not normalized."),
+    {"array", "group_id_array"}};
 }  // namespace
 
 void RegisterHashAggregateBasic(FunctionRegistry* registry) {
@@ -2412,6 +2513,22 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
     DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>)));
     DCHECK_OK(registry->AddFunction(std::move(func)));
   }
+
+  {
+    auto func = std::make_shared<HashAggregateFunction>(
+        "hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc);
+    DCHECK_OK(func->AddKernel(
+        MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedCountDistinctImpl>)));
+    DCHECK_OK(registry->AddFunction(std::move(func)));
+  }
+
+  {
+    auto func = std::make_shared<HashAggregateFunction>("hash_distinct", Arity::Binary(),
+                                                        &hash_distinct_doc);
+    DCHECK_OK(func->AddKernel(
+        MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedDistinctImpl>)));
+    DCHECK_OK(registry->AddFunction(std::move(func)));
+  }
 }
 
 }  // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index a160461b5dc..df2222a4eef 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1239,6 +1239,184 @@ TEST(GroupBy, AnyAndAll) {
   }
 }
 
+TEST(GroupBy, CountDistinct) {
+  for (bool use_threads : {true, false}) {
+    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+    auto table =
+        TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
+    [1, 1],
+    [1, 1]
+])",
+                      R"([
+    [0, 2],
+    [null, 3]
+])",
+                      R"([
+    [4, null],
+    [1, 3]
+])",
+                      R"([
+    [0, 2],
+    [-1, 2]
+])",
+                      R"([
+    [1, null],
+    [NaN, 3]
+  ])",
+                      R"([
+    [2, null],
+    [3, null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_count_distinct", nullptr},
+                             },
+                             use_threads));
+    SortBy({"key_0"}, &aggregated_and_grouped);
+    ValidateOutput(aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_count_distinct", int64()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+    [1, 1],
+    [2, 2],
+    [3, 3],
+    [4, null]
+  ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+
+    table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {R"([
+    ["foo", 1],
+    ["foo", 1]
+])",
+                      R"([
+    ["bar", 2],
+    [null, 3]
+])",
+                      R"([
+    ["baz", null],
+    ["foo", 3]
+])",
+                      R"([
+    ["bar", 2],
+    ["spam", 2]
+])",
+                      R"([
+    ["eggs", null],
+    ["ham", 3]
+  ])",
+                      R"([
+    ["a", null],
+    ["b", null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_count_distinct", nullptr},
+                             },
+                             use_threads));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    AssertDatumsEqual(ArrayFromJSON(struct_({
+                                        field("hash_count_distinct", int64()),
+                                        field("key_0", int64()),
+                                    }),
+                                    R"([
+    [1, 1],
+    [2, 2],
+    [3, 3],
+    [4, null]
+  ])"),
+                      aggregated_and_grouped,
+                      /*verbose=*/true);
+  }
+}
+
+TEST(GroupBy, Distinct) {
+  for (bool use_threads : {true, false}) {
+    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+    auto table =
+        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {R"([
+    ["foo", 1],
+    ["foo", 1]
+])",
+                      R"([
+    ["bar", 2],
+    [null, 3]
+])",
+                      R"([
+    ["baz", null],
+    ["foo", 3]
+])",
+                      R"([
+    ["bar", 2],
+    ["spam", 2]
+])",
+                      R"([
+    ["eggs", null],
+    ["ham", 3]
+  ])",
+                      R"([
+    ["a", null],
+    ["b", null]
+  ])"});
+
+    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
+                         internal::GroupBy(
+                             {
+                                 table->GetColumnByName("argument"),
+                             },
+                             {
+                                 table->GetColumnByName("key"),
+                             },
+                             {
+                                 {"hash_distinct", nullptr},
+                             },
+                             use_threads));
+    ValidateOutput(aggregated_and_grouped);
+    SortBy({"key_0"}, &aggregated_and_grouped);
+
+    // Order of sub-arrays is not stable
+    auto struct_arr = aggregated_and_grouped.array_as<StructArray>();
+    auto distinct_arr = checked_pointer_cast<ListArray>(struct_arr->field(0));
+    auto sort = [](const Array& arr) -> std::shared_ptr<Array> {
+      EXPECT_OK_AND_ASSIGN(auto indices, SortIndices(arr));
+      EXPECT_OK_AND_ASSIGN(auto sorted, Take(arr, indices));
+      return sorted.make_array();
+    };
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"),
+                      sort(*distinct_arr->value_slice(0)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["bar", "spam"])"),
+                      sort(*distinct_arr->value_slice(1)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo", "ham", null])"),
+                      sort(*distinct_arr->value_slice(2)), /*verbose=*/true);
+    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["a", "b", "baz", "eggs"])"),
+                      sort(*distinct_arr->value_slice(3)), /*verbose=*/true);
+  }
+}
+
 TEST(GroupBy, CountAndSum) {
   auto batch = RecordBatchFromJSON(
       schema({field("argument", float64()), field("key", int64())}), R"([