From 050f36f61f185ec95bfcc5cedd1e21174d212225 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 14 Sep 2021 09:20:12 -0400
Subject: [PATCH] ARROW-13849: [C++] Add min/max scalar aggregates
---
.../arrow/compute/kernels/aggregate_basic.cc | 56 +++++++++++
.../compute/kernels/aggregate_internal.h | 4 +
.../arrow/compute/kernels/aggregate_test.cc | 25 ++++-
.../arrow/compute/kernels/hash_aggregate.cc | 96 ++++++++++++++-----
.../compute/kernels/hash_aggregate_test.cc | 53 ++++++++++
docs/source/cpp/compute.rst | 8 ++
6 files changed, 215 insertions(+), 27 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 323d1e7ca9e..1851ba604ff 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -282,6 +282,43 @@ Result> MinMaxInit(KernelContext* ctx,
return visitor.Create();
}
+// For "min" and "max" functions: override finalize and return the actual value
+template
+void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,
+ ScalarAggregateFunction* min_max_func) {
+ auto sig = KernelSignature::Make(
+ {InputType(ValueDescr::ANY)},
+ OutputType([](KernelContext*,
+ const std::vector& descrs) -> Result {
+ // any[T] -> scalar[T]
+ return ValueDescr::Scalar(descrs.front().type);
+ }));
+
+ auto init = [min_max_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) -> Result> {
+ std::vector inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
+ KernelInitArgs new_args{kernel, inputs, args.options};
+ return kernel->init(ctx, new_args);
+ };
+
+ auto finalize = [](KernelContext* ctx, Datum* out) -> Status {
+ Datum temp;
+ RETURN_NOT_OK(checked_cast(ctx->state())->Finalize(ctx, &temp));
+ const auto& result = temp.scalar_as();
+ DCHECK(result.is_valid);
+ *out = result.value[static_cast(min_or_max)];
+ return Status::OK();
+ };
+
+ // Note SIMD level is always NONE, but the convenience kernel will
+ // dispatch to an appropriate implementation
+ ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
+ AggregateMerge, std::move(finalize));
+ DCHECK_OK(func->AddKernel(kernel));
+}
+
// ----------------------------------------------------------------------
// Any implementation
@@ -663,6 +700,13 @@ const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numer
{"array"},
"ScalarAggregateOptions"};
+const FunctionDoc min_or_max_doc{
+ "Compute the minimum or maximum values of a numeric array",
+ ("Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array"},
+ "ScalarAggregateOptions"};
+
const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true",
("Null values are ignored by default.\n"
"If null values are taken into account by setting "
@@ -781,6 +825,18 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
}
#endif
+ auto min_max_func = func.get();
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // Add min/max as convenience functions
+ func = std::make_shared("min", Arity::Unary(), &min_or_max_doc,
+ &default_scalar_aggregate_options);
+ aggregate::AddMinOrMaxAggKernel(func.get(), min_max_func);
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ func = std::make_shared("max", Arity::Unary(), &min_or_max_doc,
+ &default_scalar_aggregate_options);
+ aggregate::AddMinOrMaxAggKernel(func.get(), min_max_func);
DCHECK_OK(registry->AddFunction(std::move(func)));
func = std::make_shared(
diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h
index 13687d22820..b0aced3e346 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_internal.h
+++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h
@@ -98,6 +98,10 @@ struct ScalarAggregator : public KernelState {
// kernel implementations together
enum class VarOrStd : bool { Var, Std };
+// Helper to differentiate between min/max calculation so we can fold
+// kernel implementations together
+enum class MinOrMax : uint8_t { Min = 0, Max };
+
void AddAggKernel(std::shared_ptr sig, KernelInit init,
ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index 54405b229a9..f3d470c42de 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -1177,11 +1177,23 @@ class TestPrimitiveMinMaxKernel : public ::testing::Test {
ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options));
const StructScalar& value = out.scalar_as();
- const auto& out_min = checked_cast(*value.value[0]);
- ASSERT_EQ(expected_min, out_min.value);
+ {
+ const auto& out_min = checked_cast(*value.value[0]);
+ ASSERT_EQ(expected_min, out_min.value);
- const auto& out_max = checked_cast(*value.value[1]);
- ASSERT_EQ(expected_max, out_max.value);
+ const auto& out_max = checked_cast(*value.value[1]);
+ ASSERT_EQ(expected_max, out_max.value);
+ }
+
+ {
+ ASSERT_OK_AND_ASSIGN(out, CallFunction("min", {array}, &options));
+ const auto& out_min = out.scalar_as();
+ ASSERT_EQ(expected_min, out_min.value);
+
+ ASSERT_OK_AND_ASSIGN(out, CallFunction("max", {array}, &options));
+ const auto& out_max = out.scalar_as();
+ ASSERT_EQ(expected_max, out_max.value);
+ }
}
void AssertMinMaxIs(const std::string& json, c_type expected_min, c_type expected_max,
@@ -1427,6 +1439,11 @@ TEST(TestDecimalMinMaxKernel, Decimals) {
EXPECT_THAT(MinMax(chunked_input3, options),
ResultWith(ScalarFromJSON(ty, R"({"min": "1.01", "max": "9.42"})")));
+ EXPECT_THAT(CallFunction("min", {chunked_input1}, &options),
+ ResultWith(ScalarFromJSON(item_ty, R"("1.01")")));
+ EXPECT_THAT(CallFunction("max", {chunked_input1}, &options),
+ ResultWith(ScalarFromJSON(item_ty, R"("9.42")")));
+
EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, "null"), options),
ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, R"("1.00")"), options),
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 85389f95abe..2be424751d8 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -770,38 +770,34 @@ Result> HashAggregateInit(KernelContext* ctx,
return std::move(impl);
}
+Status HashAggregateResize(KernelContext* ctx, int64_t num_groups) {
+ return checked_cast(ctx->state())->Resize(num_groups);
+}
+Status HashAggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
+ return checked_cast(ctx->state())->Consume(batch);
+}
+Status HashAggregateMerge(KernelContext* ctx, KernelState&& other,
+ const ArrayData& group_id_mapping) {
+ return checked_cast(ctx->state())
+ ->Merge(checked_cast(other), group_id_mapping);
+}
+Status HashAggregateFinalize(KernelContext* ctx, Datum* out) {
+ return checked_cast(ctx->state())->Finalize().Value(out);
+}
+
HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
HashAggregateKernel kernel;
-
kernel.init = std::move(init);
-
kernel.signature = KernelSignature::Make(
{std::move(argument_type), InputType::Array(Type::UINT32)},
OutputType(
[](KernelContext* ctx, const std::vector&) -> Result {
return checked_cast(ctx->state())->out_type();
}));
-
- kernel.resize = [](KernelContext* ctx, int64_t num_groups) {
- return checked_cast(ctx->state())->Resize(num_groups);
- };
-
- kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) {
- return checked_cast(ctx->state())->Consume(batch);
- };
-
- kernel.merge = [](KernelContext* ctx, KernelState&& other,
- const ArrayData& group_id_mapping) {
- return checked_cast(ctx->state())
- ->Merge(checked_cast(other), group_id_mapping);
- };
-
- kernel.finalize = [](KernelContext* ctx, Datum* out) {
- ARROW_ASSIGN_OR_RAISE(*out,
- checked_cast(ctx->state())->Finalize());
- return Status::OK();
- };
-
+ kernel.resize = HashAggregateResize;
+ kernel.consume = HashAggregateConsume;
+ kernel.merge = HashAggregateMerge;
+ kernel.finalize = HashAggregateFinalize;
return kernel;
}
@@ -1929,6 +1925,35 @@ Result> MinMaxInit(KernelContext* ctx,
return std::move(impl);
}
+template
+HashAggregateKernel MakeMinOrMaxKernel(HashAggregateFunction* min_max_func) {
+ HashAggregateKernel kernel;
+ kernel.init = [min_max_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) -> Result> {
+ std::vector inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
+ KernelInitArgs new_args{kernel, inputs, args.options};
+ return kernel->init(ctx, new_args);
+ };
+ kernel.signature = KernelSignature::Make(
+ {InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)},
+ OutputType([](KernelContext* ctx,
+ const std::vector& descrs) -> Result {
+ return ValueDescr::Array(descrs[0].type);
+ }));
+ kernel.resize = HashAggregateResize;
+ kernel.consume = HashAggregateConsume;
+ kernel.merge = HashAggregateMerge;
+ kernel.finalize = [](KernelContext* ctx, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum temp,
+ checked_cast(ctx->state())->Finalize());
+ *out = temp.array_as()->field(static_cast(min_or_max));
+ return Status::OK();
+ };
+ return kernel;
+}
+
struct GroupedMinMaxFactory {
template
enable_if_physical_integer Visit(const T&) {
@@ -2628,6 +2653,13 @@ const FunctionDoc hash_min_max_doc{
{"array", "group_id_array"},
"ScalarAggregateOptions"};
+const FunctionDoc hash_min_or_max_doc{
+ "Compute the minimum or maximum values of a numeric array",
+ ("Null values are ignored by default.\n"
+ "This can be changed through ScalarAggregateOptions."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
+
const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
("Null values are ignored."),
{"array", "group_id_array"},
@@ -2750,6 +2782,7 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+ HashAggregateFunction* min_max_func = nullptr;
{
auto func = std::make_shared(
"hash_min_max", Arity::Binary(), &hash_min_max_doc,
@@ -2760,6 +2793,23 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(AddHashAggKernels(
{null(), boolean(), decimal128(1, 1), decimal256(1, 1), month_interval()},
GroupedMinMaxFactory::Make, func.get()));
+ min_max_func = func.get();
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared(
+ "hash_min", Arity::Binary(), &hash_min_or_max_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel(min_max_func)));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+ {
+ auto func = std::make_shared(
+ "hash_max", Arity::Binary(), &hash_min_or_max_doc,
+ &default_scalar_aggregate_options);
+ DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel(min_max_func)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 2a9ceeb7e70..c8894e530ef 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1499,6 +1499,59 @@ TEST(GroupBy, MinMaxDecimal) {
}
}
+TEST(GroupBy, MinOrMax) {
+ auto table =
+ TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
+ [1.0, 1],
+ [null, 1]
+])",
+ R"([
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2]
+])",
+ R"([
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+])",
+ R"([
+ [NaN, 4],
+ [null, 4],
+ [Inf, 4],
+ [-Inf, 4],
+ [0.0, 4]
+])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest({table->GetColumnByName("argument"),
+ table->GetColumnByName("argument")},
+ {table->GetColumnByName("key")},
+ {
+ {"hash_min", nullptr},
+ {"hash_max", nullptr},
+ },
+ /*use_threads=*/true, /*use_exec_plan=*/true));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_min", float64()),
+ field("hash_max", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [1.0, 3.25, 1],
+ [-0.25, 0.125, 2],
+ [null, null, 3],
+ [-Inf, Inf, 4],
+ [0.75, 4.0, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
TEST(GroupBy, MinMaxScalar) {
BatchesWithSchema input;
input.batches = {
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 662776b86d0..9db9c5bc563 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -199,8 +199,12 @@ the input to a single output value.
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
+| max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
++---------------+-------+------------------+------------------------+----------------------------------+-------+
| mean | Unary | Numeric | Scalar Decimal/Float64 | :struct:`ScalarAggregateOptions` | |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
+| min | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
++---------------+-------+------------------+------------------------+----------------------------------+-------+
| min_max | Unary | Non-nested types | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) |
@@ -307,8 +311,12 @@ equivalents above and reflects how they are implemented internally.
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_distinct | Unary | Any | Input type | :struct:`CountOptions` | \(2) |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
+| hash_max | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | |
++---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_mean | Unary | Numeric | Decimal/Float64 | :struct:`ScalarAggregateOptions` | |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
+| hash_min | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | |
++---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_min_max | Unary | Non-nested, non-binary/string-like | Struct | :struct:`ScalarAggregateOptions` | \(3) |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_product | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(4) |