From a5fcdcaa9564d96fa875ab1faa71af8e75f6727a Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 9 Jul 2021 16:24:01 -0400
Subject: [PATCH] ARROW-13298: [C++] Implement any/all hash aggregate kernels
---
.../arrow/compute/kernels/hash_aggregate.cc | 139 ++++++++++++++++++
.../compute/kernels/hash_aggregate_test.cc | 49 ++++++
2 files changed, 188 insertions(+)
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 79213b93b37..3e4b401bae9 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -1149,6 +1149,123 @@ struct GroupedMinMaxFactory {
InputType argument_type;
};
+// ----------------------------------------------------------------------
+// Any/All implementation
+
+// Grouped "any" aggregator over a boolean column: a group's result is true iff
+// at least one valid (non-null) input value belonging to that group is true.
+// State is one bit per group in `seen_`; every new group starts as false.
+struct GroupedAnyImpl : public GroupedAggregator {
+  // Binds the per-group bitmap builder to the execution context's memory pool.
+  // The FunctionOptions argument is not used.
+  Status Init(ExecContext* ctx, const FunctionOptions*) override {
+    seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  // Grows the state to `new_num_groups` groups by appending
+  // (new_num_groups - num_groups_) false bits.
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    return seen_.Append(added_groups, false);
+  }
+
+  // Folds the state of another GroupedAnyImpl into this one.
+  // Element i of `group_id_mapping` is the id, in this aggregator, of the
+  // other aggregator's group i.
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedAnyImpl*>(&raw_other);
+
+    auto seen = seen_.mutable_data();
+    auto other_seen = other->seen_.data();
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      // OR-merge: a bit is only ever turned on, never cleared.
+      if (BitUtil::GetBit(other_seen, other_g)) BitUtil::SetBitTo(seen, *g, true);
+    }
+    return Status::OK();
+  }
+
+  // Consumes one batch: batch[0] is the boolean input array and batch[1] holds
+  // one group id per input row.
+  Status Consume(const ExecBatch& batch) override {
+    auto seen = seen_.mutable_data();
+
+    const auto& input = *batch[0].array();
+
+    auto g = batch[1].array()->GetValues<uint32_t>(1);
+    // Visit the validity bitmap (buffers[0]) and the value bitmap (buffers[1])
+    // together: rows that are both valid and true mark their group; all other
+    // rows only advance the group id cursor.
+    arrow::internal::VisitTwoBitBlocksVoid(
+        input.buffers[0], input.offset, input.buffers[1], input.offset, input.length,
+        [&](int64_t) { BitUtil::SetBitTo(seen, *g++, true); }, [&]() { g++; });
+    return Status::OK();
+  }
+
+  // Finishes the builder and returns the per-group bits as a BooleanArray of
+  // length num_groups_.
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
+    return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return boolean(); }
+
+  int64_t num_groups_ = 0;
+  // Not read by any method of this struct.
+  ScalarAggregateOptions options_;
+  // One bit per group: set once a true value has been seen for that group.
+  TypedBufferBuilder<bool> seen_;
+};
+
+// Grouped "all" aggregator over a boolean column: a group's result is false iff
+// at least one valid (non-null) input value belonging to that group is false;
+// null inputs are skipped. State is one bit per group in `seen_`; every new
+// group starts as true.
+struct GroupedAllImpl : public GroupedAggregator {
+  // Binds the per-group bitmap builder to the execution context's memory pool.
+  // The FunctionOptions argument is not used.
+  Status Init(ExecContext* ctx, const FunctionOptions*) override {
+    seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  // Grows the state to `new_num_groups` groups by appending
+  // (new_num_groups - num_groups_) true bits.
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    return seen_.Append(added_groups, true);
+  }
+
+  // Folds the state of another GroupedAllImpl into this one.
+  // Element i of `group_id_mapping` is the id, in this aggregator, of the
+  // other aggregator's group i.
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedAllImpl*>(&raw_other);
+
+    auto seen = seen_.mutable_data();
+    auto other_seen = other->seen_.data();
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      // AND-merge: the group stays true only if both sides are true.
+      BitUtil::SetBitTo(
+          seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g));
+    }
+    return Status::OK();
+  }
+
+  // Consumes one batch: batch[0] is the boolean input array and batch[1] holds
+  // one group id per input row.
+  Status Consume(const ExecBatch& batch) override {
+    auto seen = seen_.mutable_data();
+
+    const auto& input = *batch[0].array();
+
+    auto g = batch[1].array()->GetValues<uint32_t>(1);
+    if (input.MayHaveNulls()) {
+      // Walk the validity bitmap (buffers[0]). For each valid row, AND the
+      // group's bit with the row's value bit; `position` is relative to
+      // input.offset. Null rows only advance the group id cursor.
+      const uint8_t* bitmap = input.buffers[1]->data();
+      arrow::internal::VisitBitBlocksVoid(
+          input.buffers[0], input.offset, input.length,
+          [&](int64_t position) {
+            BitUtil::SetBitTo(seen, *g,
+                              BitUtil::GetBit(seen, *g) &&
+                                  BitUtil::GetBit(bitmap, input.offset + position));
+            g++;
+          },
+          [&]() { g++; });
+    } else {
+      // No nulls: walk the value bitmap (buffers[1]) directly. Set (true)
+      // values leave the group untouched; unset (false) values clear its bit.
+      arrow::internal::VisitBitBlocksVoid(
+          input.buffers[1], input.offset, input.length, [&](int64_t) { g++; },
+          [&]() { BitUtil::SetBitTo(seen, *g++, false); });
+    }
+    return Status::OK();
+  }
+
+  // Finishes the builder and returns the per-group bits as a BooleanArray of
+  // length num_groups_.
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
+    return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return boolean(); }
+
+  int64_t num_groups_ = 0;
+  // Not read by any method of this struct.
+  ScalarAggregateOptions options_;
+  // One bit per group: cleared once a false value has been seen for that group.
+  TypedBufferBuilder<bool> seen_;
+};
+
} // namespace
Result> GetKernels(
@@ -1426,6 +1543,14 @@ const FunctionDoc hash_min_max_doc{
"This can be changed through ScalarAggregateOptions."),
{"array", "group_id_array"},
"ScalarAggregateOptions"};
+
+// Documentation record for "hash_any": one-line summary, longer description,
+// then the names of its two arguments.
+const FunctionDoc hash_any_doc{
+    "Test whether any element evaluates to true",
+    "Null values are ignored.",
+    {"array", "group_id_array"}};
+
+// Documentation record for "hash_all": one-line summary, longer description,
+// then the names of its two arguments.
+const FunctionDoc hash_all_doc{
+    "Test whether all elements evaluate to true",
+    "Null values are ignored.",
+    {"array", "group_id_array"}};
} // namespace
void RegisterHashAggregateBasic(FunctionRegistry* registry) {
@@ -1460,6 +1585,20 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(AddHashAggKernels(NumericTypes(), GroupedMinMaxFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+
+  // Register "hash_any": a binary function (values, group ids) with a single
+  // kernel accepting boolean values, backed by GroupedAnyImpl.
+  {
+    auto func = std::make_shared<HashAggregateFunction>("hash_any", Arity::Binary(),
+                                                        &hash_any_doc);
+    DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAnyImpl>)));
+    DCHECK_OK(registry->AddFunction(std::move(func)));
+  }
+
+  // Register "hash_all": a binary function (values, group ids) with a single
+  // kernel accepting boolean values, backed by GroupedAllImpl.
+  {
+    auto func = std::make_shared<HashAggregateFunction>("hash_all", Arity::Binary(),
+                                                        &hash_all_doc);
+    DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>)));
+    DCHECK_OK(registry->AddFunction(std::move(func)));
+  }
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index b0327c7aa81..46c7716abce 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -705,6 +705,55 @@ TEST(GroupBy, MinMaxOnly) {
}
}
+TEST(GroupBy, AnyAndAll) {  // hash_any / hash_all over one boolean column grouped by an int64 key
+ for (bool use_threads : {true, false}) {  // run GroupBy once with use_threads=true and once with false
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

+ auto table =  // three chunks; both the argument and the key column contain nulls
+ TableFromJSON(schema({field("argument", boolean()), field("key", int64())}), {R"([
+ [true, 1],
+ [null, 1]
+ ])",
+ R"([
+ [false, 2],
+ [null, 3],
+ [false, null],
+ [true, 1],
+ [true, 2]
+ ])",
+ R"([
+ [true, 2],
+ [false, null],
+ [null, 3]
+ ])"});

+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,  // "argument" is passed twice: once per aggregate
+ internal::GroupBy({table->GetColumnByName("argument"),
+ table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_any", nullptr},
+ {"hash_all", nullptr},
+ },
+ use_threads));
+ SortBy({"key_0"}, &aggregated_and_grouped);  // order rows by key to line up with the expected array below

+ AssertDatumsEqual(ArrayFromJSON(struct_({  // key 3 has only null inputs: expect any=false, all=true
+ field("hash_any", boolean()),
+ field("hash_all", boolean()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [true, true, 1],
+ [true, false, 2],
+ [false, true, 3],
+ [false, false, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+}
+
TEST(GroupBy, CountAndSum) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([