Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "arrow/util/int128_internal.h"
#include "arrow/util/make_unique.h"
#include "arrow/util/task_group.h"
#include "arrow/util/tdigest.h"
#include "arrow/util/thread_pool.h"
#include "arrow/visitor_inline.h"

Expand Down Expand Up @@ -1311,6 +1312,126 @@ struct GroupedVarStdFactory {
InputType argument_type;
};

// ----------------------------------------------------------------------
// TDigest implementation

using arrow::internal::TDigest;

template <typename Type>
struct GroupedTDigestImpl : public GroupedAggregator {
using CType = typename Type::c_type;

// Stores the execution context, its memory pool and a copy of the options.
//
// `options` is downcast to TDigestOptions and copied into `options_`; the
// pool obtained from `ctx` is the one later used for the output allocations
// in Finalize().
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
  ctx_ = ctx;
  pool_ = ctx->memory_pool();
  options_ = *checked_cast<const TDigestOptions*>(options);
  return Status::OK();
}

// Grows the per-group digest list so that it holds `new_num_groups` entries.
//
// Each appended digest is constructed from the configured `delta` and
// `buffer_size`; digests that already exist are left as they are.
Status Resize(int64_t new_num_groups) override {
  int64_t remaining = new_num_groups - tdigests_.size();
  tdigests_.reserve(new_num_groups);
  while (remaining > 0) {
    tdigests_.emplace_back(options_.delta, options_.buffer_size);
    --remaining;
  }
  return Status::OK();
}

// Feeds one batch into the per-group digests.
//
// batch[0] is the value array and batch[1] carries one uint32 group id per
// row. The group-id cursor advances for every row, valid or null, so it stays
// aligned with the value being visited.
Status Consume(const ExecBatch& batch) override {
  const uint32_t* group = batch[1].array()->GetValues<uint32_t>(1);
  VisitArrayDataInline<Type>(
      *batch[0].array(),
      // Valid row: add the value to the digest of its group.
      [&](typename TypeTraits<Type>::CType value) {
        this->tdigests_[*group++].NanAdd(value);
      },
      // Null row: nothing to add, just skip its group id.
      [&] { group++; });
  return Status::OK();
}

Status Merge(GroupedAggregator&& raw_other,
const ArrayData& group_id_mapping) override {
auto other = checked_cast<GroupedTDigestImpl*>(&raw_other);

auto g = group_id_mapping.GetValues<uint32_t>(1);
std::vector<TDigest> other_tdigest(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
other_tdigest[0] = std::move(other->tdigests_[other_g]);
tdigests_[*g].Merge(&other_tdigest);
Comment on lines +1359 to +1360
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember why, but this API does look awkward :(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth considering whether Merge should take an iterator pair (and whether TDigest::Merge should also take an iterator pair or vector of pointers instead of a vector of structs).

}

return Status::OK();
}

// Builds the output array: one fixed-size-list slot per group, each holding
// one float64 per requested quantile (`options_.q`).
//
// A group whose digest is empty becomes a null slot and its child values are
// zero-filled. The validity bitmap is only allocated once the first empty
// digest is encountered; if there is none, the output has no bitmap.
Result<Datum> Finalize() override {
  const int64_t num_groups = static_cast<int64_t>(tdigests_.size());
  const int64_t num_quantiles = options_.q.size();

  ARROW_ASSIGN_OR_RAISE(
      std::shared_ptr<Buffer> values,
      AllocateBuffer(tdigests_.size() * options_.q.size() * sizeof(double), pool_));
  double* quantiles = reinterpret_cast<double*>(values->mutable_data());

  std::shared_ptr<Buffer> validity;
  int64_t null_count = 0;

  for (int64_t group = 0; group < num_groups; ++group) {
    double* slot = quantiles + group * num_quantiles;

    if (tdigests_[group].is_empty()) {
      if (!validity) {
        // First null group: create the bitmap with every bit set to valid.
        ARROW_ASSIGN_OR_RAISE(validity, AllocateBitmap(tdigests_.size(), pool_));
        BitUtil::SetBitsTo(validity->mutable_data(), 0, tdigests_.size(), true);
      }
      ++null_count;
      BitUtil::SetBitTo(validity->mutable_data(), group, false);
      std::fill(slot, slot + num_quantiles, 0.0);
    } else {
      for (int64_t k = 0; k < num_quantiles; ++k) {
        slot[k] = tdigests_[group].Quantile(options_.q[k]);
      }
    }
  }

  // The flat child array never carries nulls; nullness lives on the list slots.
  auto flat_values = ArrayData::Make(float64(), tdigests_.size() * options_.q.size(),
                                     {nullptr, std::move(values)}, /*null_count=*/0);
  return ArrayData::Make(out_type(), tdigests_.size(), {std::move(validity)},
                         {std::move(flat_values)}, null_count);
}

// Output type: a fixed-size list of float64 with one entry per requested
// quantile.
std::shared_ptr<DataType> out_type() const override {
  const auto list_size = static_cast<int32_t>(options_.q.size());
  return fixed_size_list(float64(), list_size);
}

TDigestOptions options_;
std::vector<TDigest> tdigests_;
ExecContext* ctx_;
MemoryPool* pool_;
};

// Type visitor that builds the grouped t-digest kernel for one input type.
//
// Use Make(); it dispatches on the concrete type and either returns the
// kernel or a NotImplemented status for types that are not handled.
struct GroupedTDigestFactory {
  // Creates the kernel for array inputs of `type`.
  static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
    GroupedTDigestFactory visitor;
    visitor.argument_type = InputType::Array(type);
    RETURN_NOT_OK(VisitTypeInline(*type, &visitor));
    return std::move(visitor.kernel);
  }

  // Numeric types: instantiate the implementation for T.
  template <typename T>
  enable_if_number<T, Status> Visit(const T&) {
    kernel =
        MakeKernel(std::move(argument_type), HashAggregateInit<GroupedTDigestImpl<T>>);
    return Status::OK();
  }

  // Half floats are rejected exactly like any other unhandled type.
  Status Visit(const HalfFloatType& type) {
    return Visit(static_cast<const DataType&>(type));
  }

  // Fallback for every type without a dedicated overload.
  Status Visit(const DataType& type) {
    return Status::NotImplemented("Computing t-digest of data of type ", type);
  }

  HashAggregateKernel kernel;
  InputType argument_type;
};

// ----------------------------------------------------------------------
// MinMax implementation

Expand Down Expand Up @@ -1863,6 +1984,13 @@ const FunctionDoc hash_variance_doc{
"to satisfy `ddof`, null is returned."),
{"array", "group_id_array"}};

// User-facing documentation for the "hash_tdigest" grouped aggregation.
// The kernel emits one result per group, so the "no valid data" case is
// described per group rather than as a null array.
const FunctionDoc hash_tdigest_doc{
    "Calculate approximate quantiles of a numeric array with the T-Digest algorithm",
    ("By default, the 0.5 quantile (median) is returned.\n"
     "Nulls and NaNs are ignored.\n"
     "Null is returned for a group that has no valid data points."),
    {"array", "group_id_array"}};

const FunctionDoc hash_min_max_doc{
"Compute the minimum and maximum values of a numeric array",
("Null values are ignored by default.\n"
Expand Down Expand Up @@ -1939,6 +2067,19 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// Register "hash_tdigest" with kernels for signed, unsigned and floating-point
// inputs. The default options object is a function-local static and its
// address is handed to the function object.
static auto default_tdigest_options = TDigestOptions::Defaults();
{
  auto tdigest_func = std::make_shared<HashAggregateFunction>(
      "hash_tdigest", Arity::Binary(), &hash_tdigest_doc, &default_tdigest_options);
  auto* tdigest_func_ptr = tdigest_func.get();
  DCHECK_OK(
      AddHashAggKernels(SignedIntTypes(), GroupedTDigestFactory::Make, tdigest_func_ptr));
  DCHECK_OK(AddHashAggKernels(UnsignedIntTypes(), GroupedTDigestFactory::Make,
                              tdigest_func_ptr));
  DCHECK_OK(AddHashAggKernels(FloatingPointTypes(), GroupedTDigestFactory::Make,
                              tdigest_func_ptr));
  DCHECK_OK(registry->AddFunction(std::move(tdigest_func)));
}

{
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
auto func = std::make_shared<HashAggregateFunction>(
Expand Down
51 changes: 51 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,57 @@ TEST(GroupBy, VarianceAndStddev) {
/*verbose=*/true);
}

// Groups a float64 column by an int64 key and runs "hash_tdigest" three times:
// with no explicit options, with three quantiles, and with three quantiles
// plus explicit delta/buffer_size. The input includes null values, a NaN and
// null keys; key 3 only has a null and a NaN and is expected to yield null.
TEST(GroupBy, TDigest) {
  auto batch = RecordBatchFromJSON(
      schema({field("argument", float64()), field("key", int64())}), R"([
[1, 1],
[null, 1],
[0, 2],
[null, 3],
[4, null],
[3, 1],
[0, 2],
[-1, 2],
[1, null],
[NaN, 3]
])");

  const std::vector<double> quantiles{0.5, 0.9, 0.99};
  TDigestOptions options1(quantiles);
  TDigestOptions options2(quantiles, /*delta=*/50, /*buffer_size=*/1024);

  auto argument = batch->GetColumnByName("argument");
  auto key = batch->GetColumnByName("key");

  ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
                       internal::GroupBy({argument, argument, argument}, {key},
                                         {
                                             {"hash_tdigest", nullptr},
                                             {"hash_tdigest", &options1},
                                             {"hash_tdigest", &options2},
                                         }));

  // One output column per aggregate, followed by the grouping key.
  auto expected_type = struct_({
      field("hash_tdigest", fixed_size_list(float64(), 1)),
      field("hash_tdigest", fixed_size_list(float64(), 3)),
      field("hash_tdigest", fixed_size_list(float64(), 3)),
      field("key_0", int64()),
  });
  auto expected = ArrayFromJSON(expected_type,
                                R"([
[[1.0], [1.0, 3.0, 3.0], [1.0, 3.0, 3.0], 1],
[[0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], 2],
[null, null, null, 3],
[[1.0], [1.0, 4.0, 4.0], [1.0, 4.0, 4.0], null]
])");

  AssertDatumsApproxEqual(expected, aggregated_and_grouped,
                          /*verbose=*/true);
}

TEST(GroupBy, MinMaxOnly) {
for (bool use_exec_plan : {false, true}) {
for (bool use_threads : {true, false}) {
Expand Down