diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 79213b93b37..3e4b401bae9 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -1149,6 +1149,123 @@ struct GroupedMinMaxFactory { InputType argument_type; }; +// ---------------------------------------------------------------------- +// Any/All implementation + +struct GroupedAnyImpl : public GroupedAggregator { + Status Init(ExecContext* ctx, const FunctionOptions*) override { + seen_ = TypedBufferBuilder(ctx->memory_pool()); + return Status::OK(); + } + + Status Resize(int64_t new_num_groups) override { + auto added_groups = new_num_groups - num_groups_; + num_groups_ = new_num_groups; + return seen_.Append(added_groups, false); + } + + Status Merge(GroupedAggregator&& raw_other, + const ArrayData& group_id_mapping) override { + auto other = checked_cast(&raw_other); + + auto seen = seen_.mutable_data(); + auto other_seen = other->seen_.data(); + + auto g = group_id_mapping.GetValues(1); + for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { + if (BitUtil::GetBit(other_seen, other_g)) BitUtil::SetBitTo(seen, *g, true); + } + return Status::OK(); + } + + Status Consume(const ExecBatch& batch) override { + auto seen = seen_.mutable_data(); + + const auto& input = *batch[0].array(); + + auto g = batch[1].array()->GetValues(1); + arrow::internal::VisitTwoBitBlocksVoid( + input.buffers[0], input.offset, input.buffers[1], input.offset, input.length, + [&](int64_t) { BitUtil::SetBitTo(seen, *g++, true); }, [&]() { g++; }); + return Status::OK(); + } + + Result Finalize() override { + ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish()); + return std::make_shared(num_groups_, std::move(seen)); + } + + std::shared_ptr out_type() const override { return boolean(); } + + int64_t num_groups_ = 0; + ScalarAggregateOptions options_; + TypedBufferBuilder seen_; +}; + +struct GroupedAllImpl : public GroupedAggregator { + Status Init(ExecContext* ctx, const FunctionOptions*) override { + seen_ = TypedBufferBuilder(ctx->memory_pool()); + return Status::OK(); + } + + Status Resize(int64_t new_num_groups) override { + auto added_groups = new_num_groups - num_groups_; + num_groups_ = new_num_groups; + return seen_.Append(added_groups, true); + } + + Status Merge(GroupedAggregator&& raw_other, + const ArrayData& group_id_mapping) override { + auto other = checked_cast(&raw_other); + + auto seen = seen_.mutable_data(); + auto other_seen = other->seen_.data(); + + auto g = group_id_mapping.GetValues(1); + for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { + BitUtil::SetBitTo( + seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g)); + } + return Status::OK(); + } + + Status Consume(const ExecBatch& batch) override { + auto seen = seen_.mutable_data(); + + const auto& input = *batch[0].array(); + + auto g = batch[1].array()->GetValues(1); + if (input.MayHaveNulls()) { + const uint8_t* bitmap = input.buffers[1]->data(); + arrow::internal::VisitBitBlocksVoid( + input.buffers[0], input.offset, input.length, + [&](int64_t position) { + BitUtil::SetBitTo(seen, *g, + BitUtil::GetBit(seen, *g) && + BitUtil::GetBit(bitmap, input.offset + position)); + g++; + }, + [&]() { g++; }); + } else { + arrow::internal::VisitBitBlocksVoid( + input.buffers[1], input.offset, input.length, [&](int64_t) { g++; }, + [&]() { BitUtil::SetBitTo(seen, *g++, false); }); + } + return Status::OK(); + } + + Result Finalize() override { + ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish()); + return std::make_shared(num_groups_, std::move(seen)); + } + + std::shared_ptr out_type() const override { return boolean(); } + + int64_t num_groups_ = 0; + ScalarAggregateOptions options_; + TypedBufferBuilder seen_; +}; + } // namespace Result> GetKernels( @@ -1426,6 +1543,14 @@ const FunctionDoc hash_min_max_doc{ "This can be changed through ScalarAggregateOptions."), {"array", "group_id_array"}, "ScalarAggregateOptions"}; + +const FunctionDoc hash_any_doc{"Test whether any element evaluates to true", + ("Null values are ignored."), + {"array", "group_id_array"}}; + +const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true", + ("Null values are ignored."), + {"array", "group_id_array"}}; } // namespace void RegisterHashAggregateBasic(FunctionRegistry* registry) { @@ -1460,6 +1585,20 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) { DCHECK_OK(AddHashAggKernels(NumericTypes(), GroupedMinMaxFactory::Make, func.get())); DCHECK_OK(registry->AddFunction(std::move(func))); } + + { + auto func = std::make_shared("hash_any", Arity::Binary(), + &hash_any_doc); + DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit))); + DCHECK_OK(registry->AddFunction(std::move(func))); + } + + { + auto func = std::make_shared("hash_all", Arity::Binary(), + &hash_all_doc); + DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit))); + DCHECK_OK(registry->AddFunction(std::move(func))); + } } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index b0327c7aa81..46c7716abce 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -705,6 +705,55 @@ TEST(GroupBy, MinMaxOnly) { } } +TEST(GroupBy, AnyAndAll) { + for (bool use_threads : {true, false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + + auto table = + TableFromJSON(schema({field("argument", boolean()), field("key", int64())}), {R"([ + [true, 1], + [null, 1] + ])", + R"([ + [false, 2], + [null, 3], + [false, null], + [true, 1], + [true, 2] + ])", + R"([ + [true, 2], + [false, null], + [null, 3] + ])"}); + + ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, + internal::GroupBy({table->GetColumnByName("argument"), + table->GetColumnByName("argument")}, + {table->GetColumnByName("key")}, + { + {"hash_any", nullptr}, + {"hash_all", nullptr}, + }, + use_threads)); + SortBy({"key_0"}, &aggregated_and_grouped); + + AssertDatumsEqual(ArrayFromJSON(struct_({ + field("hash_any", boolean()), + field("hash_all", boolean()), + field("key_0", int64()), + }), + R"([ + [true, true, 1], + [true, false, 2], + [false, true, 3], + [false, false, null] + ])"), + aggregated_and_grouped, + /*verbose=*/true); + } +} + TEST(GroupBy, CountAndSum) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", int64())}), R"([