From 050f36f61f185ec95bfcc5cedd1e21174d212225 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 14 Sep 2021 09:20:12 -0400 Subject: [PATCH] ARROW-13849: [C++] Add min/max scalar aggregates --- .../arrow/compute/kernels/aggregate_basic.cc | 56 +++++++++++ .../compute/kernels/aggregate_internal.h | 4 + .../arrow/compute/kernels/aggregate_test.cc | 25 ++++- .../arrow/compute/kernels/hash_aggregate.cc | 96 ++++++++++++++----- .../compute/kernels/hash_aggregate_test.cc | 53 ++++++++++ docs/source/cpp/compute.rst | 8 ++ 6 files changed, 215 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 323d1e7ca9e..1851ba604ff 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -282,6 +282,43 @@ Result> MinMaxInit(KernelContext* ctx, return visitor.Create(); } +// For "min" and "max" functions: override finalize and return the actual value +template +void AddMinOrMaxAggKernel(ScalarAggregateFunction* func, + ScalarAggregateFunction* min_max_func) { + auto sig = KernelSignature::Make( + {InputType(ValueDescr::ANY)}, + OutputType([](KernelContext*, + const std::vector& descrs) -> Result { + // any[T] -> scalar[T] + return ValueDescr::Scalar(descrs.front().type); + })); + + auto init = [min_max_func]( + KernelContext* ctx, + const KernelInitArgs& args) -> Result> { + std::vector inputs = args.inputs; + ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs)); + KernelInitArgs new_args{kernel, inputs, args.options}; + return kernel->init(ctx, new_args); + }; + + auto finalize = [](KernelContext* ctx, Datum* out) -> Status { + Datum temp; + RETURN_NOT_OK(checked_cast(ctx->state())->Finalize(ctx, &temp)); + const auto& result = temp.scalar_as(); + DCHECK(result.is_valid); + *out = result.value[static_cast(min_or_max)]; + return Status::OK(); + }; + + // Note SIMD level is always NONE, but the convenience kernel will + // dispatch to an appropriate implementation + ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume, + AggregateMerge, std::move(finalize)); + DCHECK_OK(func->AddKernel(kernel)); +} + // ---------------------------------------------------------------------- // Any implementation @@ -663,6 +700,13 @@ const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numer {"array"}, "ScalarAggregateOptions"}; +const FunctionDoc min_or_max_doc{ + "Compute the minimum or maximum values of a numeric array", + ("Null values are ignored by default.\n" + "This can be changed through ScalarAggregateOptions."), + {"array"}, + "ScalarAggregateOptions"}; + const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true", ("Null values are ignored by default.\n" "If null values are taken into account by setting " @@ -781,6 +825,18 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { } #endif + auto min_max_func = func.get(); + DCHECK_OK(registry->AddFunction(std::move(func))); + + // Add min/max as convenience functions + func = std::make_shared("min", Arity::Unary(), &min_or_max_doc, + &default_scalar_aggregate_options); + aggregate::AddMinOrMaxAggKernel(func.get(), min_max_func); + DCHECK_OK(registry->AddFunction(std::move(func))); + + func = std::make_shared("max", Arity::Unary(), &min_or_max_doc, + &default_scalar_aggregate_options); + aggregate::AddMinOrMaxAggKernel(func.get(), min_max_func); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared( diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h index 13687d22820..b0aced3e346 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h @@ -98,6 +98,10 @@ struct ScalarAggregator : public KernelState { // kernel implementations together enum class VarOrStd : bool { Var, Std }; +// Helper to differentiate between min/max calculation so we can fold +// kernel implementations together +enum class MinOrMax : uint8_t { Min = 0, Max }; + void AddAggKernel(std::shared_ptr sig, KernelInit init, ScalarAggregateFunction* func, SimdLevel::type simd_level = SimdLevel::NONE); diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 54405b229a9..f3d470c42de 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -1177,11 +1177,23 @@ class TestPrimitiveMinMaxKernel : public ::testing::Test { ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options)); const StructScalar& value = out.scalar_as(); - const auto& out_min = checked_cast(*value.value[0]); - ASSERT_EQ(expected_min, out_min.value); + { + const auto& out_min = checked_cast(*value.value[0]); + ASSERT_EQ(expected_min, out_min.value); - const auto& out_max = checked_cast(*value.value[1]); - ASSERT_EQ(expected_max, out_max.value); + const auto& out_max = checked_cast(*value.value[1]); + ASSERT_EQ(expected_max, out_max.value); + } + + { + ASSERT_OK_AND_ASSIGN(out, CallFunction("min", {array}, &options)); + const auto& out_min = out.scalar_as(); + ASSERT_EQ(expected_min, out_min.value); + + ASSERT_OK_AND_ASSIGN(out, CallFunction("max", {array}, &options)); + const auto& out_max = out.scalar_as(); + ASSERT_EQ(expected_max, out_max.value); + } } void AssertMinMaxIs(const std::string& json, c_type expected_min, c_type expected_max, @@ -1427,6 +1439,11 @@ TEST(TestDecimalMinMaxKernel, Decimals) { EXPECT_THAT(MinMax(chunked_input3, options), ResultWith(ScalarFromJSON(ty, R"({"min": "1.01", "max": "9.42"})"))); + EXPECT_THAT(CallFunction("min", {chunked_input1}, &options), + ResultWith(ScalarFromJSON(item_ty, R"("1.01")"))); + EXPECT_THAT(CallFunction("max", {chunked_input1}, &options), + ResultWith(ScalarFromJSON(item_ty, R"("9.42")"))); + EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, "null"), options), ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})"))); EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, R"("1.00")"), options), diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 85389f95abe..2be424751d8 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -770,38 +770,34 @@ Result> HashAggregateInit(KernelContext* ctx, return std::move(impl); } +Status HashAggregateResize(KernelContext* ctx, int64_t num_groups) { + return checked_cast(ctx->state())->Resize(num_groups); +} +Status HashAggregateConsume(KernelContext* ctx, const ExecBatch& batch) { + return checked_cast(ctx->state())->Consume(batch); +} +Status HashAggregateMerge(KernelContext* ctx, KernelState&& other, + const ArrayData& group_id_mapping) { + return checked_cast(ctx->state()) + ->Merge(checked_cast(other), group_id_mapping); +} +Status HashAggregateFinalize(KernelContext* ctx, Datum* out) { + return checked_cast(ctx->state())->Finalize().Value(out); +} + HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) { HashAggregateKernel kernel; - kernel.init = std::move(init); - kernel.signature = KernelSignature::Make( {std::move(argument_type), InputType::Array(Type::UINT32)}, OutputType( [](KernelContext* ctx, const std::vector&) -> Result { return checked_cast(ctx->state())->out_type(); })); - - kernel.resize = [](KernelContext* ctx, int64_t num_groups) { - return checked_cast(ctx->state())->Resize(num_groups); - }; - - kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) { - return checked_cast(ctx->state())->Consume(batch); - }; - - kernel.merge = [](KernelContext* ctx, KernelState&& other, - const ArrayData& group_id_mapping) { - return checked_cast(ctx->state()) - ->Merge(checked_cast(other), group_id_mapping); - }; - - kernel.finalize = [](KernelContext* ctx, Datum* out) { - ARROW_ASSIGN_OR_RAISE(*out, - checked_cast(ctx->state())->Finalize()); - return Status::OK(); - }; - + kernel.resize = HashAggregateResize; + kernel.consume = HashAggregateConsume; + kernel.merge = HashAggregateMerge; + kernel.finalize = HashAggregateFinalize; return kernel; } @@ -1929,6 +1925,35 @@ Result> MinMaxInit(KernelContext* ctx, return std::move(impl); } +template +HashAggregateKernel MakeMinOrMaxKernel(HashAggregateFunction* min_max_func) { + HashAggregateKernel kernel; + kernel.init = [min_max_func]( + KernelContext* ctx, + const KernelInitArgs& args) -> Result> { + std::vector inputs = args.inputs; + ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs)); + KernelInitArgs new_args{kernel, inputs, args.options}; + return kernel->init(ctx, new_args); + }; + kernel.signature = KernelSignature::Make( + {InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)}, + OutputType([](KernelContext* ctx, + const std::vector& descrs) -> Result { + return ValueDescr::Array(descrs[0].type); + })); + kernel.resize = HashAggregateResize; + kernel.consume = HashAggregateConsume; + kernel.merge = HashAggregateMerge; + kernel.finalize = [](KernelContext* ctx, Datum* out) { + ARROW_ASSIGN_OR_RAISE(Datum temp, + checked_cast(ctx->state())->Finalize()); + *out = temp.array_as()->field(static_cast(min_or_max)); + return Status::OK(); + }; + return kernel; +} + struct GroupedMinMaxFactory { template enable_if_physical_integer Visit(const T&) { @@ -2628,6 +2653,13 @@ const FunctionDoc hash_min_max_doc{ {"array", "group_id_array"}, "ScalarAggregateOptions"}; +const FunctionDoc hash_min_or_max_doc{ + "Compute the minimum or maximum values of a numeric array", + ("Null values are ignored by default.\n" + "This can be changed through ScalarAggregateOptions."), + {"array", "group_id_array"}, + "ScalarAggregateOptions"}; + const FunctionDoc hash_any_doc{"Test whether any element evaluates to true", ("Null values are ignored."), {"array", "group_id_array"}, @@ -2750,6 +2782,7 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(func))); } + HashAggregateFunction* min_max_func = nullptr; { auto func = std::make_shared( "hash_min_max", Arity::Binary(), &hash_min_max_doc, @@ -2760,6 +2793,23 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) { DCHECK_OK(AddHashAggKernels( {null(), boolean(), decimal128(1, 1), decimal256(1, 1), month_interval()}, GroupedMinMaxFactory::Make, func.get())); + min_max_func = func.get(); + DCHECK_OK(registry->AddFunction(std::move(func))); + } + + { + auto func = std::make_shared( + "hash_min", Arity::Binary(), &hash_min_or_max_doc, + &default_scalar_aggregate_options); + DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel(min_max_func))); + DCHECK_OK(registry->AddFunction(std::move(func))); + } + + { + auto func = std::make_shared( + "hash_max", Arity::Binary(), &hash_min_or_max_doc, + &default_scalar_aggregate_options); + DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel(min_max_func))); DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 2a9ceeb7e70..c8894e530ef 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -1499,6 +1499,59 @@ TEST(GroupBy, MinMaxDecimal) { } } +TEST(GroupBy, MinOrMax) { + auto table = + TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([ + [1.0, 1], + [null, 1] +])", + R"([ + [0.0, 2], + [null, 3], + [4.0, null], + [3.25, 1], + [0.125, 2] +])", + R"([ + [-0.25, 2], + [0.75, null], + [null, 3] +])", + R"([ + [NaN, 4], + [null, 4], + [Inf, 4], + [-Inf, 4], + [0.0, 4] +])"}); + + ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, + GroupByTest({table->GetColumnByName("argument"), + table->GetColumnByName("argument")}, + {table->GetColumnByName("key")}, + { + {"hash_min", nullptr}, + {"hash_max", nullptr}, + }, + /*use_threads=*/true, /*use_exec_plan=*/true)); + SortBy({"key_0"}, &aggregated_and_grouped); + + AssertDatumsEqual(ArrayFromJSON(struct_({ + field("hash_min", float64()), + field("hash_max", float64()), + field("key_0", int64()), + }), + R"([ + [1.0, 3.25, 1], + [-0.25, 0.125, 2], + [null, null, 3], + [-Inf, Inf, 4], + [0.75, 4.0, null] + ])"), + aggregated_and_grouped, + /*verbose=*/true); +} + TEST(GroupBy, MinMaxScalar) { BatchesWithSchema input; input.batches = { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 662776b86d0..9db9c5bc563 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -199,8 +199,12 @@ the input to a single output value. +---------------+-------+------------------+------------------------+----------------------------------+-------+ | index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | | +---------------+-------+------------------+------------------------+----------------------------------+-------+ +| max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | | ++---------------+-------+------------------+------------------------+----------------------------------+-------+ | mean | Unary | Numeric | Scalar Decimal/Float64 | :struct:`ScalarAggregateOptions` | | +---------------+-------+------------------+------------------------+----------------------------------+-------+ +| min | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | | ++---------------+-------+------------------+------------------------+----------------------------------+-------+ | min_max | Unary | Non-nested types | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) | +---------------+-------+------------------+------------------------+----------------------------------+-------+ | mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) | @@ -307,8 +311,12 @@ equivalents above and reflects how they are implemented internally. +---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+ | hash_distinct | Unary | Any | Input type | :struct:`CountOptions` | \(2) | +---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+ +| hash_max | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | | ++---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+ | hash_mean | Unary | Numeric | Decimal/Float64 | :struct:`ScalarAggregateOptions` | | +---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+ +| hash_min | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | | ++---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+ | hash_min_max | Unary | Non-nested, non-binary/string-like | Struct | :struct:`ScalarAggregateOptions` | \(3) | +---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+ | hash_product | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(4) |