Skip to content
Closed
117 changes: 117 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,97 @@ struct GroupedAllImpl : public GroupedBooleanAggregator<GroupedAllImpl> {
num_groups, /*out_offset=*/0, no_nulls);
}
};

// ----------------------------------------------------------------------
// CountDistinct/Distinct implementation

// Group-by aggregator counting the number of distinct values per group.
// Distinct-ness is tracked by feeding (value, group_id) pairs into an
// internal Grouper: each unique pair survives exactly once, so the count
// for a group is the number of surviving pairs carrying its group id.
// Per the kernel doc, nulls are counted and NaNs/signed zeroes are not
// normalized (the Grouper keys on raw values).
struct GroupedCountDistinctImpl : public GroupedAggregator {
  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
    // No options to interpret: distinct-counting needs no configuration.
    ctx_ = ctx;
    pool_ = ctx->memory_pool();
    return Status::OK();
  }

  Status Resize(int64_t new_num_groups) override {
    // The grouper_ keys on (value, group_id) pairs and sizes itself from the
    // pairs it sees; here we only record the group count for Finalize().
    num_groups_ = new_num_groups;
    return Status::OK();
  }

  Status Consume(const ExecBatch& batch) override {
    // batch holds {values, group_ids}; the grouper deduplicates the pairs,
    // so each distinct value within a group is recorded exactly once.
    return grouper_->Consume(batch).status();
  }

  Status Merge(GroupedAggregator&& raw_other,
               const ArrayData& group_id_mapping) override {
    auto other = checked_cast<GroupedCountDistinctImpl*>(&raw_other);

    // Get the other state's (value, group_id) pairs, translate its group IDs
    // into ours via group_id_mapping, then consume the pairs ourselves.
    ARROW_ASSIGN_OR_RAISE(auto uniques, other->grouper_->GetUniques());
    ARROW_ASSIGN_OR_RAISE(auto remapped_g,
                          AllocateBuffer(uniques.length * sizeof(uint32_t), pool_));

    const auto* g_mapping = group_id_mapping.GetValues<uint32_t>(1);
    const auto* other_g = uniques[1].array()->GetValues<uint32_t>(1);
    auto* g = reinterpret_cast<uint32_t*>(remapped_g->mutable_data());

    for (int64_t i = 0; i < uniques.length; i++) {
      g[i] = g_mapping[other_g[i]];
    }
    // Replace the group-id column with the remapped ids (no null bitmap).
    uniques.values[1] =
        ArrayData::Make(uint32(), uniques.length, {nullptr, std::move(remapped_g)});

    return Consume(std::move(uniques));
  }

  Result<Datum> Finalize() override {
    // One int64 count per group; zero-fill since a group may have no pairs.
    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
                          AllocateBuffer(num_groups_ * sizeof(int64_t), pool_));
    int64_t* counts = reinterpret_cast<int64_t*>(values->mutable_data());
    std::fill(counts, counts + num_groups_, 0);

    // Each surviving (value, group_id) pair contributes one to its group.
    ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
    auto* g = uniques[1].array()->GetValues<uint32_t>(1);
    for (int64_t i = 0; i < uniques.length; i++) {
      counts[g[i]]++;
    }

    return ArrayData::Make(int64(), num_groups_, {nullptr, std::move(values)},
                           /*null_count=*/0);
  }

  std::shared_ptr<DataType> out_type() const override { return int64(); }

  // Defensively default-initialized; Init()/Resize() set the real values.
  ExecContext* ctx_ = nullptr;
  MemoryPool* pool_ = nullptr;
  int64_t num_groups_ = 0;
  std::unique_ptr<Grouper> grouper_;
  std::shared_ptr<DataType> out_type_;  // input value type, set by the kernel init
};

// Variant of GroupedCountDistinctImpl that keeps the distinct values
// themselves: Finalize() emits one list per group containing that group's
// unique values (nulls kept, per the kernel doc).
struct GroupedDistinctImpl : public GroupedCountDistinctImpl {
  Result<Datum> Finalize() override {
    // Pull out the deduplicated (value, group_id) pairs accumulated so far.
    ARROW_ASSIGN_OR_RAISE(auto uniques, grouper_->GetUniques());
    const auto& group_ids = *uniques[1].array_as<UInt32Array>();
    // Build per-group selection indices, then gather each group's values
    // into its own list.
    ARROW_ASSIGN_OR_RAISE(
        auto groupings,
        grouper_->MakeGroupings(group_ids, static_cast<uint32_t>(num_groups_), ctx_));
    return grouper_->ApplyGroupings(*groupings, *uniques[0].make_array(), ctx_);
  }

  std::shared_ptr<DataType> out_type() const override { return list(out_type_); }
};

// Shared kernel-init for hash_count_distinct / hash_distinct: builds the
// generic aggregator state, then wires up the input value type and the
// internal grouper the Impl relies on.
template <typename Impl>
Result<std::unique_ptr<KernelState>> GroupedDistinctInit(KernelContext* ctx,
                                                         const KernelInitArgs& args) {
  ARROW_ASSIGN_OR_RAISE(auto state, HashAggregateInit<Impl>(ctx, args));
  auto* self = static_cast<Impl*>(state.get());
  // Remember the value type; GroupedDistinctImpl uses it for its list output.
  self->out_type_ = args.inputs[0].type;
  // The grouper keys on (value, group_id) pairs to deduplicate per group.
  ARROW_ASSIGN_OR_RAISE(self->grouper_, Grouper::Make(args.inputs, ctx->exec_context()));
  return std::move(state);
}

} // namespace

Result<std::vector<const HashAggregateKernel*>> GetKernels(
Expand Down Expand Up @@ -2289,6 +2380,16 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",
("Null values are ignored."),
{"array", "group_id_array"},
"ScalarAggregateOptions"};

// Documentation for the distinct hash-aggregate kernels registered below.
// No options struct: neither kernel takes FunctionOptions.
const FunctionDoc hash_count_distinct_doc{
    "Count the distinct values in each group",
    ("Nulls are counted. NaNs and signed zeroes are not normalized."),
    {"array", "group_id_array"}};

const FunctionDoc hash_distinct_doc{
    "Keep the distinct values in each group",
    ("Nulls are kept. NaNs and signed zeroes are not normalized."),
    {"array", "group_id_array"}};
} // namespace

void RegisterHashAggregateBasic(FunctionRegistry* registry) {
Expand Down Expand Up @@ -2412,6 +2513,22 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(func->AddKernel(MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

  {
    // hash_count_distinct: emits an int64 count of unique values per group.
    auto func = std::make_shared<HashAggregateFunction>(
        "hash_count_distinct", Arity::Binary(), &hash_count_distinct_doc);
    DCHECK_OK(func->AddKernel(
        MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedCountDistinctImpl>)));
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }

  {
    // hash_distinct: emits a list of the unique values per group.
    auto func = std::make_shared<HashAggregateFunction>("hash_distinct", Arity::Binary(),
                                                        &hash_distinct_doc);
    DCHECK_OK(func->AddKernel(
        MakeKernel(ValueDescr::ARRAY, GroupedDistinctInit<GroupedDistinctImpl>)));
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }
}

} // namespace internal
Expand Down
178 changes: 178 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,184 @@ TEST(GroupBy, AnyAndAll) {
}
}

// Exercises hash_count_distinct over floats and strings, serial and parallel.
// The input is split into several batches so the threaded path hits Merge().
// Fix: both sub-cases now ValidateOutput() before SortBy(), matching the
// second sub-case and TEST(GroupBy, Distinct); the first case previously
// sorted first, so it validated a reordered datum instead of the raw output.
TEST(GroupBy, CountDistinct) {
  for (bool use_threads : {true, false}) {
    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

    // Float input. Per-group distinct values:
    //   key 1 -> {1}; key 2 -> {0, -1}; key 3 -> {null, 1, NaN};
    //   key null -> {4, 1, 2, 3}.
    auto table =
        TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
    [1, 1],
    [1, 1]
])",
                                                                                      R"([
    [0, 2],
    [null, 3]
])",
                                                                                      R"([
    [4, null],
    [1, 3]
])",
                                                                                      R"([
    [0, 2],
    [-1, 2]
])",
                                                                                      R"([
    [1, null],
    [NaN, 3]
])",
                                                                                      R"([
    [2, null],
    [3, null]
])"});

    ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
                         internal::GroupBy(
                             {
                                 table->GetColumnByName("argument"),
                             },
                             {
                                 table->GetColumnByName("key"),
                             },
                             {
                                 {"hash_count_distinct", nullptr},
                             },
                             use_threads));
    ValidateOutput(aggregated_and_grouped);
    SortBy({"key_0"}, &aggregated_and_grouped);

    AssertDatumsEqual(ArrayFromJSON(struct_({
                                        field("hash_count_distinct", int64()),
                                        field("key_0", int64()),
                                    }),
                                    R"([
    [1, 1],
    [2, 2],
    [3, 3],
    [4, null]
])"),
                      aggregated_and_grouped,
                      /*verbose=*/true);

    // String input. Per-group distinct values:
    //   key 1 -> {"foo"}; key 2 -> {"bar", "spam"};
    //   key 3 -> {null, "foo", "ham"}; key null -> {"baz", "eggs", "a", "b"}.
    table =
        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {R"([
    ["foo", 1],
    ["foo", 1]
])",
                                                                                   R"([
    ["bar", 2],
    [null, 3]
])",
                                                                                   R"([
    ["baz", null],
    ["foo", 3]
])",
                                                                                   R"([
    ["bar", 2],
    ["spam", 2]
])",
                                                                                   R"([
    ["eggs", null],
    ["ham", 3]
])",
                                                                                   R"([
    ["a", null],
    ["b", null]
])"});

    ASSERT_OK_AND_ASSIGN(aggregated_and_grouped,
                         internal::GroupBy(
                             {
                                 table->GetColumnByName("argument"),
                             },
                             {
                                 table->GetColumnByName("key"),
                             },
                             {
                                 {"hash_count_distinct", nullptr},
                             },
                             use_threads));
    ValidateOutput(aggregated_and_grouped);
    SortBy({"key_0"}, &aggregated_and_grouped);

    AssertDatumsEqual(ArrayFromJSON(struct_({
                                        field("hash_count_distinct", int64()),
                                        field("key_0", int64()),
                                    }),
                                    R"([
    [1, 1],
    [2, 2],
    [3, 3],
    [4, null]
])"),
                      aggregated_and_grouped,
                      /*verbose=*/true);
  }
}

// Exercises hash_distinct over strings, serial and parallel. Uses the same
// fixture as the string case of TEST(GroupBy, CountDistinct), so the expected
// per-group value sets are: key 1 -> {"foo"}; key 2 -> {"bar", "spam"};
// key 3 -> {null, "foo", "ham"}; key null -> {"baz", "eggs", "a", "b"}.
TEST(GroupBy, Distinct) {
  for (bool use_threads : {true, false}) {
    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

    auto table =
        TableFromJSON(schema({field("argument", utf8()), field("key", int64())}), {R"([
    ["foo", 1],
    ["foo", 1]
])",
                                                                                   R"([
    ["bar", 2],
    [null, 3]
])",
                                                                                   R"([
    ["baz", null],
    ["foo", 3]
])",
                                                                                   R"([
    ["bar", 2],
    ["spam", 2]
])",
                                                                                   R"([
    ["eggs", null],
    ["ham", 3]
])",
                                                                                   R"([
    ["a", null],
    ["b", null]
])"});

    ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped,
                         internal::GroupBy(
                             {
                                 table->GetColumnByName("argument"),
                             },
                             {
                                 table->GetColumnByName("key"),
                             },
                             {
                                 {"hash_distinct", nullptr},
                             },
                             use_threads));
    ValidateOutput(aggregated_and_grouped);
    SortBy({"key_0"}, &aggregated_and_grouped);

    // Order of sub-arrays is not stable, so sort each group's value list
    // before comparing against the expected (sorted) sets. SortIndices puts
    // nulls last, hence null appears at the end of group 3's expectation.
    auto struct_arr = aggregated_and_grouped.array_as<StructArray>();
    auto distinct_arr = checked_pointer_cast<ListArray>(struct_arr->field(0));
    auto sort = [](const Array& arr) -> std::shared_ptr<Array> {
      EXPECT_OK_AND_ASSIGN(auto indices, SortIndices(arr));
      EXPECT_OK_AND_ASSIGN(auto sorted, Take(arr, indices));
      return sorted.make_array();
    };
    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo"])"),
                      sort(*distinct_arr->value_slice(0)), /*verbose=*/true);
    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["bar", "spam"])"),
                      sort(*distinct_arr->value_slice(1)), /*verbose=*/true);
    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["foo", "ham", null])"),
                      sort(*distinct_arr->value_slice(2)), /*verbose=*/true);
    AssertDatumsEqual(ArrayFromJSON(utf8(), R"(["a", "b", "baz", "eggs"])"),
                      sort(*distinct_arr->value_slice(3)), /*verbose=*/true);
  }
}

TEST(GroupBy, CountAndSum) {
auto batch = RecordBatchFromJSON(
schema({field("argument", float64()), field("key", int64())}), R"([
Expand Down