Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 81 additions & 11 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "arrow/compute/kernels/aggregate_var_std_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/record_batch.h"
#include "arrow/util/bit_run_reader.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_writer.h"
Expand Down Expand Up @@ -1970,6 +1971,7 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {
/// \brief Store the execution context, its memory pool and the count options.
///
/// \param ctx execution context; its memory pool is cached in pool_
/// \param options dereferenced and cast to CountOptions (no null check here)
/// \return always Status::OK()
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
  // Resolve the concrete options type first; the copy into options_ follows.
  const auto& count_options = checked_cast<const CountOptions&>(*options);
  ctx_ = ctx;
  pool_ = ctx_->memory_pool();
  options_ = count_options;
  return Status::OK();
}

Expand Down Expand Up @@ -2013,8 +2015,21 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {

ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
auto* g = uniques[1].array()->GetValues<uint32_t>(1);
for (int64_t i = 0; i < uniques.length; i++) {
counts[g[i]]++;
const auto& items = *uniques[0].array();
const auto* valid = items.GetValues<uint8_t>(0, 0);
if (options_.mode == CountOptions::ALL ||
(options_.mode == CountOptions::ONLY_VALID && !valid)) {
for (int64_t i = 0; i < uniques.length; i++) {
counts[g[i]]++;
}
} else if (options_.mode == CountOptions::ONLY_VALID) {
for (int64_t i = 0; i < uniques.length; i++) {
counts[g[i]] += BitUtil::GetBit(valid, items.offset + i);
}
} else if (valid) { // ONLY_NULL
for (int64_t i = 0; i < uniques.length; i++) {
counts[g[i]] += !BitUtil::GetBit(valid, items.offset + i);
}
}

return ArrayData::Make(int64(), num_groups_, {nullptr, std::move(values)},
Expand All @@ -2026,6 +2041,7 @@ struct GroupedCountDistinctImpl : public GroupedAggregator {
ExecContext* ctx_;
MemoryPool* pool_;
int64_t num_groups_;
CountOptions options_;
std::unique_ptr<Grouper> grouper_;
std::shared_ptr<DataType> out_type_;
};
Expand All @@ -2036,7 +2052,56 @@ struct GroupedDistinctImpl : public GroupedCountDistinctImpl {
ARROW_ASSIGN_OR_RAISE(auto groupings, grouper_->MakeGroupings(
*uniques[1].array_as<UInt32Array>(),
static_cast<uint32_t>(num_groups_), ctx_));
return grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_);
ARROW_ASSIGN_OR_RAISE(
auto list, grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_));
auto values = list->values();
DCHECK_EQ(values->offset(), 0);
int32_t* offsets = reinterpret_cast<int32_t*>(list->value_offsets()->mutable_data());
if (options_.mode == CountOptions::ALL ||
(options_.mode == CountOptions::ONLY_VALID && values->null_count() == 0)) {
return list;
} else if (options_.mode == CountOptions::ONLY_VALID) {
int32_t prev_offset = offsets[0];
for (int64_t i = 0; i < list->length(); i++) {
const int32_t slot_length = offsets[i + 1] - prev_offset;
const int64_t null_count =
slot_length - arrow::internal::CountSetBits(values->null_bitmap()->data(),
prev_offset, slot_length);
DCHECK_LE(null_count, 1);
const int32_t offset = null_count > 0 ? slot_length - 1 : slot_length;
prev_offset = offsets[i + 1];
offsets[i + 1] = offsets[i] + offset;
}
auto filter =
std::make_shared<BooleanArray>(values->length(), values->null_bitmap());
ARROW_ASSIGN_OR_RAISE(
auto new_values,
Filter(std::move(values), filter, FilterOptions(FilterOptions::DROP), ctx_));
return std::make_shared<ListArray>(list->type(), list->length(),
list->value_offsets(), new_values.make_array());
}
// ONLY_NULL
if (values->null_count() == 0) {
std::fill(offsets + 1, offsets + list->length() + 1, offsets[0]);
} else {
int32_t prev_offset = offsets[0];
for (int64_t i = 0; i < list->length(); i++) {
const int32_t slot_length = offsets[i + 1] - prev_offset;
const int64_t null_count =
slot_length - arrow::internal::CountSetBits(values->null_bitmap()->data(),
prev_offset, slot_length);
const int32_t offset = null_count > 0 ? 1 : 0;
prev_offset = offsets[i + 1];
offsets[i + 1] = offsets[i] + offset;
}
}
ARROW_ASSIGN_OR_RAISE(
auto new_values,
MakeArrayOfNull(out_type_,
list->length() > 0 ? offsets[list->length()] - offsets[0] : 0,
pool_));
return std::make_shared<ListArray>(list->type(), list->length(),
list->value_offsets(), std::move(new_values));
}

/// \brief Output type of this aggregator: a list wrapping out_type_.
std::shared_ptr<DataType> out_type() const override {
  std::shared_ptr<DataType> wrapped = list(out_type_);
  return wrapped;
}
Expand Down Expand Up @@ -2383,22 +2448,26 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",

// Documentation record passed when registering "hash_count_distinct".
// Entries, in order: one-line summary, longer description, argument names,
// and the name of the options class ("CountOptions") that selects whether
// nulls and/or valid values are counted.
const FunctionDoc hash_count_distinct_doc{
    "Count the distinct values in each group",
    ("Whether nulls/values are counted is controlled by CountOptions.\n"
     "NaNs and signed zeroes are not normalized."),
    {"array", "group_id_array"},
    "CountOptions"};

// Documentation record passed when registering "hash_distinct".
// Entries, in order: one-line summary, longer description, argument names,
// and the name of the options class ("CountOptions") that selects whether
// nulls and/or valid values are kept.
const FunctionDoc hash_distinct_doc{
    "Keep the distinct values in each group",
    ("Whether nulls/values are kept is controlled by CountOptions.\n"
     "NaNs and signed zeroes are not normalized."),
    {"array", "group_id_array"},
    "CountOptions"};
} // namespace

void RegisterHashAggregateBasic(FunctionRegistry* registry) {
static auto default_count_options = CountOptions::Defaults();
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
static auto default_tdigest_options = TDigestOptions::Defaults();
static auto default_variance_options = VarianceOptions::Defaults();

{
static auto default_count_options = CountOptions::Defaults();
auto func = std::make_shared<HashAggregateFunction>(
"hash_count", Arity::Binary(), &hash_count_doc, &default_count_options);

Expand Down Expand Up @@ -2516,15 +2585,16 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {

{
auto func = std::make_shared<HashAggregateFunction>(
"hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc);
"hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc,
&default_count_options);
DCHECK_OK(func->AddKernel(
MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedCountDistinctImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
auto func = std::make_shared<HashAggregateFunction>("hash_distinct", Arity::Binary(),
&hash_distinct_doc);
auto func = std::make_shared<HashAggregateFunction>(
"hash_distinct", Arity::Binary(), &hash_distinct_doc, &default_count_options);
DCHECK_OK(func->AddKernel(
MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedDistinctImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
Expand Down
Loading