Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,123 @@ struct GroupedMinMaxFactory {
InputType argument_type;
};

// ----------------------------------------------------------------------
// Any/All implementation

/// \brief Grouped "any" aggregator over a boolean column.
///
/// One bit per group is kept in `seen_`. Bits start out false (see Resize) and
/// are flipped to true either by Consume, for rows whose bits in both
/// `buffers[0]` and `buffers[1]` of the input are set, or by Merge, when the
/// other aggregator's bit for the mapped group is set.
struct GroupedAnyImpl : public GroupedAggregator {
  /// Re-create the per-group bit builder on the context's memory pool.
  /// The options argument is ignored.
  Status Init(ExecContext* ctx, const FunctionOptions*) override {
    seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
    return Status::OK();
  }

  /// Extend the state up to `new_num_groups`; each added group starts as false.
  Status Resize(int64_t new_num_groups) override {
    const auto num_new_groups = new_num_groups - num_groups_;
    num_groups_ = new_num_groups;
    return seen_.Append(num_new_groups, false);
  }

  /// Fold another GroupedAnyImpl into this one.
  ///
  /// `group_id_mapping` holds, for each group index of `raw_other`, the index of
  /// the corresponding group in this aggregator (read as uint32 from buffer 1).
  Status Merge(GroupedAggregator&& raw_other,
               const ArrayData& group_id_mapping) override {
    auto other = checked_cast<GroupedAnyImpl*>(&raw_other);

    uint8_t* merged_bits = seen_.mutable_data();
    const uint8_t* incoming_bits = other->seen_.data();
    const uint32_t* target_groups = group_id_mapping.GetValues<uint32_t>(1);

    for (int64_t i = 0; i < group_id_mapping.length; ++i) {
      // Only a set bit on the other side can change our state.
      if (!BitUtil::GetBit(incoming_bits, i)) continue;
      BitUtil::SetBitTo(merged_bits, target_groups[i], true);
    }
    return Status::OK();
  }

  /// Consume one batch: `batch[0]` is the boolean input array and `batch[1]`
  /// the uint32 group id of each row.
  Status Consume(const ExecBatch& batch) override {
    const auto& values = *batch[0].array();
    uint8_t* any_bits = seen_.mutable_data();
    const uint32_t* group_ids = batch[1].array()->GetValues<uint32_t>(1);

    // `row` is advanced by exactly one of the two visitors for every input
    // position, so it always indexes the group id of the row being visited.
    int64_t row = 0;
    auto mark_group = [&](int64_t) {
      BitUtil::SetBitTo(any_bits, group_ids[row++], true);
    };
    auto skip_row = [&]() { ++row; };

    // The first visitor runs for positions whose bits are set in both bitmaps;
    // every other position goes to the second visitor.
    arrow::internal::VisitTwoBitBlocksVoid(values.buffers[0], values.offset,
                                           values.buffers[1], values.offset,
                                           values.length, mark_group, skip_row);
    return Status::OK();
  }

  /// Emit the per-group bits as a BooleanArray of `num_groups_` elements.
  Result<Datum> Finalize() override {
    ARROW_ASSIGN_OR_RAISE(auto bits, seen_.Finish());
    return std::make_shared<BooleanArray>(num_groups_, std::move(bits));
  }

  std::shared_ptr<DataType> out_type() const override { return boolean(); }

  // Number of groups currently tracked (i.e. number of bits appended to seen_).
  int64_t num_groups_ = 0;
  // Not read anywhere inside this struct.
  ScalarAggregateOptions options_;
  // Bit i is true once group i has been marked by Consume or Merge.
  TypedBufferBuilder<bool> seen_;
};

/// \brief Grouped "all" aggregator over a boolean column.
///
/// One bit per group is kept in `seen_`. Bits start out true (see Resize) and
/// are cleared by Consume when a row's value bit is unset (rows whose
/// `buffers[0]` bit is unset are skipped when the input may have nulls), or by
/// Merge when the other aggregator's bit for the mapped group is unset.
struct GroupedAllImpl : public GroupedAggregator {
  /// Re-create the per-group bit builder on the context's memory pool.
  /// The options argument is ignored.
  Status Init(ExecContext* ctx, const FunctionOptions*) override {
    seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
    return Status::OK();
  }

  /// Extend the state up to `new_num_groups`; each added group starts as true.
  Status Resize(int64_t new_num_groups) override {
    const auto num_new_groups = new_num_groups - num_groups_;
    num_groups_ = new_num_groups;
    return seen_.Append(num_new_groups, true);
  }

  /// Fold another GroupedAllImpl into this one by AND-ing matching group bits.
  ///
  /// `group_id_mapping` holds, for each group index of `raw_other`, the index of
  /// the corresponding group in this aggregator (read as uint32 from buffer 1).
  Status Merge(GroupedAggregator&& raw_other,
               const ArrayData& group_id_mapping) override {
    auto other = checked_cast<GroupedAllImpl*>(&raw_other);

    uint8_t* merged_bits = seen_.mutable_data();
    const uint8_t* incoming_bits = other->seen_.data();
    const uint32_t* target_groups = group_id_mapping.GetValues<uint32_t>(1);

    for (int64_t i = 0; i < group_id_mapping.length; ++i) {
      const uint32_t target = target_groups[i];
      const bool still_all =
          BitUtil::GetBit(merged_bits, target) && BitUtil::GetBit(incoming_bits, i);
      BitUtil::SetBitTo(merged_bits, target, still_all);
    }
    return Status::OK();
  }

  /// Consume one batch: `batch[0]` is the boolean input array and `batch[1]`
  /// the uint32 group id of each row.
  Status Consume(const ExecBatch& batch) override {
    const auto& values = *batch[0].array();
    uint8_t* all_bits = seen_.mutable_data();
    const uint32_t* group_ids = batch[1].array()->GetValues<uint32_t>(1);

    // `row` is advanced by exactly one visitor per input position, so it always
    // indexes the group id of the row being visited.
    int64_t row = 0;

    if (values.MayHaveNulls()) {
      // Walk buffers[0]; for positions whose bit is set there, AND the row's
      // value bit into its group. Positions with an unset bit are skipped.
      const uint8_t* value_bits = values.buffers[1]->data();
      arrow::internal::VisitBitBlocksVoid(
          values.buffers[0], values.offset, values.length,
          [&](int64_t position) {
            const uint32_t group = group_ids[row++];
            const bool still_all =
                BitUtil::GetBit(all_bits, group) &&
                BitUtil::GetBit(value_bits, values.offset + position);
            BitUtil::SetBitTo(all_bits, group, still_all);
          },
          [&]() { ++row; });
    } else {
      // Walk the value bits directly: a set bit leaves the group untouched, an
      // unset bit clears the group's flag.
      arrow::internal::VisitBitBlocksVoid(
          values.buffers[1], values.offset, values.length,
          [&](int64_t) { ++row; },
          [&]() { BitUtil::SetBitTo(all_bits, group_ids[row++], false); });
    }
    return Status::OK();
  }

  /// Emit the per-group bits as a BooleanArray of `num_groups_` elements.
  Result<Datum> Finalize() override {
    ARROW_ASSIGN_OR_RAISE(auto bits, seen_.Finish());
    return std::make_shared<BooleanArray>(num_groups_, std::move(bits));
  }

  std::shared_ptr<DataType> out_type() const override { return boolean(); }

  // Number of groups currently tracked (i.e. number of bits appended to seen_).
  int64_t num_groups_ = 0;
  // Not read anywhere inside this struct.
  ScalarAggregateOptions options_;
  // Bit i stays true until Consume or Merge clears it for group i.
  TypedBufferBuilder<bool> seen_;
};

} // namespace

Result<std::vector<const HashAggregateKernel*>> GetKernels(
Expand Down Expand Up @@ -1426,6 +1543,14 @@ const FunctionDoc hash_min_max_doc{
"This can be changed through ScalarAggregateOptions."),
{"array", "group_id_array"},
"ScalarAggregateOptions"};

// Documentation for the "hash_any" function: summary, description, and the
// names of its two arguments.
const FunctionDoc hash_any_doc{
    "Test whether any element evaluates to true",
    "Null values are ignored.",
    {"array", "group_id_array"}};

// Documentation for the "hash_all" function: summary, description, and the
// names of its two arguments.
const FunctionDoc hash_all_doc{
    "Test whether all elements evaluate to true",
    "Null values are ignored.",
    {"array", "group_id_array"}};
} // namespace

void RegisterHashAggregateBasic(FunctionRegistry* registry) {
Expand Down Expand Up @@ -1460,6 +1585,20 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(AddHashAggKernels(NumericTypes(), GroupedMinMaxFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
// Register "hash_any": a binary hash-aggregate function documented by
// hash_any_doc, with a single kernel built from boolean() and
// HashAggregateInit<GroupedAnyImpl>.
auto func = std::make_shared<HashAggregateFunction>("hash_any", Arity::Binary(),
&hash_any_doc);
DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAnyImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
// Register "hash_all": a binary hash-aggregate function documented by
// hash_all_doc, with a single kernel built from boolean() and
// HashAggregateInit<GroupedAllImpl>.
auto func = std::make_shared<HashAggregateFunction>("hash_all", Arity::Binary(),
&hash_all_doc);
DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
}

} // namespace internal
Expand Down
49 changes: 49 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,55 @@ TEST(GroupBy, MinMaxOnly) {
}
}

// Checks hash_any and hash_all on the same boolean column grouped by an int64
// key. Expected results per group (from the JSON at the end of the test):
//   key 1    (true, null, true)   -> any = true,  all = true
//   key 2    (false, true, true)  -> any = true,  all = false
//   key 3    (null, null)         -> any = false, all = true
//   key null (false, false)       -> any = false, all = false
TEST(GroupBy, AnyAndAll) {
// Run the same check with use_threads = true and false; SCOPED_TRACE labels
// which variant is executing.
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

// Input table built from three JSON chunks of (argument, key) rows; both
// columns contain nulls.
auto table =
TableFromJSON(schema({field("argument", boolean()), field("key", int64())}), {R"([
[true, 1],
[null, 1]
])",
R"([
[false, 2],
[null, 3],
[false, null],
[true, 1],
[true, 2]
])",
R"([
[true, 2],
[false, null],
[null, 3]
])"});

// The "argument" column is passed twice, once for each aggregate; both
// aggregates are given null options.
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy({table->GetColumnByName("argument"),
table->GetColumnByName("argument")},
{table->GetColumnByName("key")},
{
{"hash_any", nullptr},
{"hash_all", nullptr},
},
use_threads));
// Sort by key before comparing (presumably because the order of groups in the
// GroupBy output is not fixed -- confirm against GroupBy's contract).
SortBy({"key_0"}, &aggregated_and_grouped);

// Null arguments do not influence either aggregate: key 3 has only null
// arguments and yields any = false, all = true.
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_any", boolean()),
field("hash_all", boolean()),
field("key_0", int64()),
}),
R"([
[true, true, 1],
[true, false, 2],
[false, true, 3],
[false, false, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}
}

TEST(GroupBy, CountAndSum) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
Expand Down