ARROW-13849: [C++] Wrap min_max with min/max functions #11152
@@ -770,38 +770,34 @@ Result<std::unique_ptr<KernelState>> HashAggregateInit(KernelContext* ctx,
```cpp
  return std::move(impl);
}

Status HashAggregateResize(KernelContext* ctx, int64_t num_groups) {
  return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
}
Status HashAggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
  return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
}
Status HashAggregateMerge(KernelContext* ctx, KernelState&& other,
                          const ArrayData& group_id_mapping) {
  return checked_cast<GroupedAggregator*>(ctx->state())
      ->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
}
Status HashAggregateFinalize(KernelContext* ctx, Datum* out) {
  return checked_cast<GroupedAggregator*>(ctx->state())->Finalize().Value(out);
}

HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
  HashAggregateKernel kernel;

  kernel.init = std::move(init);

  kernel.signature = KernelSignature::Make(
      {std::move(argument_type), InputType::Array(Type::UINT32)},
      OutputType(
          [](KernelContext* ctx, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
            return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
          }));

  kernel.resize = [](KernelContext* ctx, int64_t num_groups) {
    return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
  };

  kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) {
    return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
  };

  kernel.merge = [](KernelContext* ctx, KernelState&& other,
                    const ArrayData& group_id_mapping) {
    return checked_cast<GroupedAggregator*>(ctx->state())
        ->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
  };

  kernel.finalize = [](KernelContext* ctx, Datum* out) {
    ARROW_ASSIGN_OR_RAISE(*out,
                          checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
    return Status::OK();
  };

  kernel.resize = HashAggregateResize;
  kernel.consume = HashAggregateConsume;
  kernel.merge = HashAggregateMerge;
  kernel.finalize = HashAggregateFinalize;
  return kernel;
}
```

@@ -1929,6 +1925,35 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
```cpp
  return std::move(impl);
}

template <MinOrMax min_or_max>
HashAggregateKernel MakeMinOrMaxKernel(HashAggregateFunction* min_max_func) {
  HashAggregateKernel kernel;
  kernel.init = [min_max_func](
                    KernelContext* ctx,
                    const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
    std::vector<ValueDescr> inputs = args.inputs;
    ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
    KernelInitArgs new_args{kernel, inputs, args.options};
    return kernel->init(ctx, new_args);
  };
  kernel.signature = KernelSignature::Make(
      {InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)},
      OutputType([](KernelContext* ctx,
                    const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
        return ValueDescr::Array(descrs[0].type);
      }));

  kernel.resize = HashAggregateResize;
  kernel.consume = HashAggregateConsume;
  kernel.merge = HashAggregateMerge;
  kernel.finalize = [](KernelContext* ctx, Datum* out) {
    ARROW_ASSIGN_OR_RAISE(Datum temp,
                          checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
    *out = temp.array_as<StructArray>()->field(static_cast<uint8_t>(min_or_max));
    return Status::OK();
  };
  return kernel;
}

struct GroupedMinMaxFactory {
  template <typename T>
  enable_if_physical_integer<T, Status> Visit(const T&) {
```

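The finalize override above is the heart of the wrapper: it runs the wrapped hash_min_max state to completion and then keeps only one child of the resulting struct<min, max> array. A minimal standalone sketch of that extraction step, using the test-only ArrayFromJSON convenience for illustration (the KeepField helper and the sample values are not part of the PR):

```cpp
#include <arrow/api.h>
#include <arrow/testing/gtest_util.h>  // arrow::ArrayFromJSON (test-only helper)

// Keep a single child of a struct-typed Datum: in hash_min_max's output,
// index 0 is "min" and index 1 is "max", which is what MinOrMax is assumed
// to encode.
arrow::Datum KeepField(const arrow::Datum& min_max_result, int index) {
  auto as_struct = min_max_result.array_as<arrow::StructArray>();
  return arrow::Datum(as_struct->field(index));
}

void Demo() {
  // Shaped like hash_min_max output: one struct<min, max> entry per group.
  auto min_max = arrow::ArrayFromJSON(
      arrow::struct_({arrow::field("min", arrow::float64()),
                      arrow::field("max", arrow::float64())}),
      R"([[1.0, 3.25], [-0.25, 0.125]])");
  arrow::Datum mins = KeepField(min_max, /*index=*/0);  // yields [1.0, -0.25]
}
```
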
@@ -2628,6 +2653,13 @@ const FunctionDoc hash_min_max_doc{
```cpp
    {"array", "group_id_array"},
    "ScalarAggregateOptions"};

const FunctionDoc hash_min_or_max_doc{
    "Compute the minimum or maximum values of a numeric array",
    ("Null values are ignored by default.\n"
     "This can be changed through ScalarAggregateOptions."),
    {"array", "group_id_array"},
    "ScalarAggregateOptions"};

const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
                               ("Null values are ignored."),
                               {"array", "group_id_array"},
```

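The doc string defers null handling to ScalarAggregateOptions. As a hedged illustration of what that means for callers (the variable name is illustrative; the test further below simply passes nullptr to keep the defaults):

```cpp
#include <arrow/compute/api_aggregate.h>

// With skip_nulls=false, a group containing any null produces a null result
// instead of ignoring the nulls (the default described above).
arrow::compute::ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false);
```
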
@@ -2750,6 +2782,7 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
```cpp
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }

  HashAggregateFunction* min_max_func = nullptr;
  {
    auto func = std::make_shared<HashAggregateFunction>(
        "hash_min_max", Arity::Binary(), &hash_min_max_doc,
```

@@ -2760,6 +2793,23 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
```cpp
    DCHECK_OK(AddHashAggKernels(
        {null(), boolean(), decimal128(1, 1), decimal256(1, 1), month_interval()},
        GroupedMinMaxFactory::Make, func.get()));
    min_max_func = func.get();
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }

  {
    auto func = std::make_shared<HashAggregateFunction>(
        "hash_min", Arity::Binary(), &hash_min_or_max_doc,
        &default_scalar_aggregate_options);
    DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel<MinOrMax::Min>(min_max_func)));
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }

  {
    auto func = std::make_shared<HashAggregateFunction>(
        "hash_max", Arity::Binary(), &hash_min_or_max_doc,
        &default_scalar_aggregate_options);
    DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel<MinOrMax::Max>(min_max_func)));
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }
```

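Once RegisterHashAggregateBasic has run, hash_min and hash_max sit in the default function registry next to hash_min_max. A hypothetical sanity-check snippet (not part of the PR) that simply looks the three names up:

```cpp
#include <arrow/compute/registry.h>

#include <iostream>

int main() {
  auto* registry = arrow::compute::GetFunctionRegistry();
  for (const char* name : {"hash_min_max", "hash_min", "hash_max"}) {
    auto maybe_function = registry->GetFunction(name);
    std::cout << name << ": "
              << (maybe_function.ok() ? "registered" : "not registered") << std::endl;
  }
  return 0;
}
```
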
Second file in the diff:
@@ -1499,6 +1499,59 @@ TEST(GroupBy, MinMaxDecimal) {
```cpp
  }
}

TEST(GroupBy, MinOrMax) {
  auto table = TableFromJSON(
      schema({field("argument", float64()), field("key", int64())}),
      {R"([
    [1.0,   1],
    [null,  1]
])",
       R"([
    [0.0,   2],
    [null,  3],
    [4.0,   null],
    [3.25,  1],
    [0.125, 2]
])",
       R"([
    [-0.25, 2],
    [0.75,  null],
    [null,  3]
])",
       R"([
    [NaN,   4],
    [null,  4],
    [Inf,   4],
    [-Inf,  4],
    [0.0,   4]
])"});

  ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
                       GroupByTest({table->GetColumnByName("argument"),
                                    table->GetColumnByName("argument")},
                                   {table->GetColumnByName("key")},
                                   {
                                       {"hash_min", nullptr},
                                       {"hash_max", nullptr},
                                   },
                                   /*use_threads=*/true, /*use_exec_plan=*/true));
  SortBy({"key_0"}, &aggregated_and_grouped);

  AssertDatumsEqual(ArrayFromJSON(struct_({
                                      field("hash_min", float64()),
                                      field("hash_max", float64()),
                                      field("key_0", int64()),
                                  }),
                                  R"([
    [1.0,   3.25,  1],
    [-0.25, 0.125, 2],
    [null,  null,  3],
    [-Inf,  Inf,   4],
    [0.75,  4.0,   null]
])"),
                    aggregated_and_grouped,
                    /*verbose=*/true);
}

TEST(GroupBy, MinMaxScalar) {
  BatchesWithSchema input;
  input.batches = {
```

If the earlier comment suggesting an enum is accepted, then use that enum as the template parameter.
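
For reference, a hedged sketch of what using the enum as a template parameter looks like, consistent with the MakeMinOrMaxKernel<MinOrMax::Min> and MakeMinOrMaxKernel<MinOrMax::Max> instantiations in the diff above; the enum definition itself is not visible in these hunks, so its exact spelling is an assumption:

```cpp
#include <cstdint>

// Assumed definition: the enumerator values double as indices into
// hash_min_max's struct<min, max> result, which is why the kernel can select
// the field with static_cast<uint8_t>(min_or_max).
enum class MinOrMax : uint8_t { Min = 0, Max = 1 };

// Passing the enum as a non-type template parameter fixes the field index at
// compile time, so hash_min and hash_max each get their own instantiation
// instead of branching on a runtime flag.
template <MinOrMax min_or_max>
constexpr uint8_t FieldIndex() {
  return static_cast<uint8_t>(min_or_max);
}

static_assert(FieldIndex<MinOrMax::Min>() == 0, "min is field 0");
static_assert(FieldIndex<MinOrMax::Max>() == 1, "max is field 1");
```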