Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1718,11 +1718,11 @@ struct GroupedMinMaxFactory {
}

Status Visit(const HalfFloatType& type) {
return Status::NotImplemented("Summing data of type ", type);
return Status::NotImplemented("Computing min/max of data of type ", type);
}

Status Visit(const DataType& type) {
return Status::NotImplemented("Summing data of type ", type);
return Status::NotImplemented("Computing min/max of data of type ", type);
}

static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
Expand All @@ -1740,15 +1740,18 @@ struct GroupedMinMaxFactory {
// Any/All implementation

struct GroupedAnyImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const FunctionOptions*) override {
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
options_ = *checked_cast<const ScalarAggregateOptions*>(options);
seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
has_nulls_ = TypedBufferBuilder<bool>(ctx->memory_pool());
return Status::OK();
}

Status Resize(int64_t new_num_groups) override {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
return seen_.Append(added_groups, false);
RETURN_NOT_OK(seen_.Append(added_groups, false));
return has_nulls_.Append(added_groups, false);
}

Status Merge(GroupedAggregator&& raw_other,
Expand All @@ -1757,48 +1760,74 @@ struct GroupedAnyImpl : public GroupedAggregator {

auto seen = seen_.mutable_data();
auto other_seen = other->seen_.data();
auto has_nulls = has_nulls_.mutable_data();
auto other_has_nulls = other->has_nulls_.data();

auto g = group_id_mapping.GetValues<uint32_t>(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
if (BitUtil::GetBit(other_seen, other_g)) BitUtil::SetBitTo(seen, *g, true);
if (BitUtil::GetBit(other_has_nulls, other_g)) {
BitUtil::SetBitTo(has_nulls, *g, true);
}
}
return Status::OK();
}

Status Consume(const ExecBatch& batch) override {
auto seen = seen_.mutable_data();
auto has_nulls = has_nulls_.mutable_data();

const auto& input = *batch[0].array();

auto g = batch[1].array()->GetValues<uint32_t>(1);
arrow::internal::VisitTwoBitBlocksVoid(
input.buffers[0], input.offset, input.buffers[1], input.offset, input.length,
[&](int64_t) { BitUtil::SetBitTo(seen, *g++, true); }, [&]() { g++; });
auto values = input.buffers[1]->data();
arrow::internal::VisitBitBlocksVoid(
input.buffers[0], input.offset, input.length,
[&](int64_t offset) {
BitUtil::SetBitTo(seen, *g,
BitUtil::GetBit(seen, *g) ||
BitUtil::GetBit(values, input.offset + offset));
g++;
},
[&]() { BitUtil::SetBitTo(has_nulls, *g++, true); });
return Status::OK();
}

Result<Datum> Finalize() override {
ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
if (options_.skip_nulls) {
return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
}
ARROW_ASSIGN_OR_RAISE(auto bitmap, has_nulls_.Finish());
// null if (~seen & has_nulls) -> not null if (seen | ~has_nulls)
::arrow::internal::BitmapOrNot(seen->data(), /*left_offset=*/0, bitmap->data(),
/*right_offset=*/0, num_groups_, /*out_offset=*/0,
bitmap->mutable_data());
return std::make_shared<BooleanArray>(num_groups_, std::move(seen),
std::move(bitmap));
}

std::shared_ptr<DataType> out_type() const override { return boolean(); }

int64_t num_groups_ = 0;
ScalarAggregateOptions options_;
TypedBufferBuilder<bool> seen_;
TypedBufferBuilder<bool> has_nulls_;
};

struct GroupedAllImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const FunctionOptions*) override {
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
options_ = *checked_cast<const ScalarAggregateOptions*>(options);
seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
has_nulls_ = TypedBufferBuilder<bool>(ctx->memory_pool());
return Status::OK();
}

Status Resize(int64_t new_num_groups) override {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
return seen_.Append(added_groups, true);
RETURN_NOT_OK(seen_.Append(added_groups, true));
return has_nulls_.Append(added_groups, false);
}

Status Merge(GroupedAggregator&& raw_other,
Expand All @@ -1807,17 +1836,23 @@ struct GroupedAllImpl : public GroupedAggregator {

auto seen = seen_.mutable_data();
auto other_seen = other->seen_.data();
auto has_nulls = has_nulls_.mutable_data();
auto other_has_nulls = other->has_nulls_.data();

auto g = group_id_mapping.GetValues<uint32_t>(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
BitUtil::SetBitTo(
seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g));
if (BitUtil::GetBit(other_has_nulls, other_g)) {
BitUtil::SetBitTo(has_nulls, *g, true);
}
}
return Status::OK();
}

Status Consume(const ExecBatch& batch) override {
auto seen = seen_.mutable_data();
auto has_nulls = has_nulls_.mutable_data();

const auto& input = *batch[0].array();

Expand All @@ -1832,7 +1867,7 @@ struct GroupedAllImpl : public GroupedAggregator {
BitUtil::GetBit(bitmap, input.offset + position));
g++;
},
[&]() { g++; });
[&]() { BitUtil::SetBitTo(has_nulls, *g++, true); });
} else {
arrow::internal::VisitBitBlocksVoid(
input.buffers[1], input.offset, input.length, [&](int64_t) { g++; },
Expand All @@ -1843,14 +1878,26 @@ struct GroupedAllImpl : public GroupedAggregator {

Result<Datum> Finalize() override {
ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
if (options_.skip_nulls) {
return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
}
ARROW_ASSIGN_OR_RAISE(auto bitmap, has_nulls_.Finish());
// null if (seen & has_nulls)
::arrow::internal::BitmapAnd(seen->data(), /*left_offset=*/0, bitmap->data(),
/*right_offset=*/0, num_groups_, /*out_offset=*/0,
bitmap->mutable_data());
::arrow::internal::InvertBitmap(bitmap->data(), /*offset=*/0, num_groups_,
bitmap->mutable_data(), /*dest_offset=*/0);
return std::make_shared<BooleanArray>(num_groups_, std::move(seen),
std::move(bitmap));
}

std::shared_ptr<DataType> out_type() const override { return boolean(); }

int64_t num_groups_ = 0;
ScalarAggregateOptions options_;
TypedBufferBuilder<bool> seen_;
TypedBufferBuilder<bool> has_nulls_;
};

} // namespace
Expand Down Expand Up @@ -2122,7 +2169,8 @@ const FunctionDoc hash_count_doc{"Count the number of null / non-null values",

const FunctionDoc hash_sum_doc{"Sum values of a numeric array",
("Null values are ignored."),
{"array", "group_id_array"}};
{"array", "group_id_array"},
"ScalarAggregateOptions"};

const FunctionDoc hash_product_doc{
"Compute product of values of a numeric array",
Expand All @@ -2132,7 +2180,8 @@ const FunctionDoc hash_product_doc{

const FunctionDoc hash_mean_doc{"Average values of a numeric array",
("Null values are ignored."),
{"array", "group_id_array"}};
{"array", "group_id_array"},
"ScalarAggregateOptions"};

const FunctionDoc hash_stddev_doc{
"Calculate the standard deviation of a numeric array",
Expand All @@ -2155,7 +2204,8 @@ const FunctionDoc hash_tdigest_doc{
("By default, the 0.5 quantile (median) is returned.\n"
"Nulls and NaNs are ignored.\n"
"A null array is returned if there are no valid data points."),
{"array", "group_id_array"}};
{"array", "group_id_array"},
"TDigestOptions"};

const FunctionDoc hash_min_max_doc{
"Compute the minimum and maximum values of a numeric array",
Expand All @@ -2175,6 +2225,9 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",

void RegisterHashAggregateBasic(FunctionRegistry* registry) {
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
static auto default_tdigest_options = TDigestOptions::Defaults();
static auto default_variance_options = VarianceOptions::Defaults();

{
static auto default_count_options = CountOptions::Defaults();
auto func = std::make_shared<HashAggregateFunction>(
Expand Down Expand Up @@ -2222,7 +2275,6 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

static auto default_variance_options = VarianceOptions::Defaults();
{
auto func = std::make_shared<HashAggregateFunction>(
"hash_stddev", Arity::Binary(), &hash_stddev_doc, &default_variance_options);
Expand All @@ -2247,7 +2299,6 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

static auto default_tdigest_options = TDigestOptions::Defaults();
{
auto func = std::make_shared<HashAggregateFunction>(
"hash_tdigest", Arity::Binary(), &hash_tdigest_doc, &default_tdigest_options);
Expand All @@ -2264,7 +2315,6 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
auto func = std::make_shared<HashAggregateFunction>(
"hash_min_max", Arity::Binary(), &hash_min_max_doc,
&default_scalar_aggregate_options);
DCHECK_OK(AddHashAggKernels({boolean()}, GroupedSumFactory::Make, func.get()));
DCHECK_OK(AddHashAggKernels(NumericTypes(), GroupedMinMaxFactory::Make, func.get()));
// Type parameters are ignored
DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
Expand All @@ -2273,15 +2323,15 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
}

{
auto func = std::make_shared<HashAggregateFunction>("hash_any", Arity::Binary(),
&hash_any_doc);
auto func = std::make_shared<HashAggregateFunction>(
"hash_any", Arity::Binary(), &hash_any_doc, &default_scalar_aggregate_options);
DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAnyImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
auto func = std::make_shared<HashAggregateFunction>("hash_all", Arity::Binary(),
&hash_all_doc);
auto func = std::make_shared<HashAggregateFunction>(
"hash_all", Arity::Binary(), &hash_all_doc, &default_scalar_aggregate_options);
DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
Expand Down
47 changes: 34 additions & 13 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@ TEST(GroupBy, MinMaxDecimal) {
}

TEST(GroupBy, AnyAndAll) {
ScalarAggregateOptions options(/*skip_nulls=*/false);
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

Expand All @@ -1094,37 +1095,57 @@ TEST(GroupBy, AnyAndAll) {
R"([
[false, 2],
[null, 3],
[null, 4],
[false, 4],
[true, 5],
[false, null],
[true, 1],
[true, 2]
])",
R"([
[true, 2],
[false, 2],
[false, null],
[null, 3]
])"});

ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy({table->GetColumnByName("argument"),
table->GetColumnByName("argument")},
{table->GetColumnByName("key")},
{
{"hash_any", nullptr},
{"hash_all", nullptr},
},
use_threads));
internal::GroupBy(
{
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
},
{table->GetColumnByName("key")},
{
{"hash_any", nullptr},
{"hash_all", nullptr},
{"hash_any", &options},
{"hash_all", &options},
},
use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);

// Group 1: trues and nulls
// Group 2: trues and falses
// Group 3: nulls
// Group 4: falses and nulls
// Group 5: trues
// Group null: falses
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_any", boolean()),
field("hash_all", boolean()),
field("hash_any", boolean()),
field("hash_all", boolean()),
field("key_0", int64()),
}),
R"([
[true, true, 1],
[true, false, 2],
[false, true, 3],
[false, false, null]
[true, true, true, null, 1],
[true, false, true, false, 2],
[false, true, null, null, 3],
[false, false, null, false, 4],
[true, true, true, true, 5],
[false, false, false, false, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
Expand Down
Loading