From 1e7925d2d7e43e1b73937448774433f101237ef3 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 4 Apr 2023 09:41:14 -0400 Subject: [PATCH 1/2] GH-34786: [C++] Fix output schema calculated by Substrait consumer for AggregateRel --- cpp/src/arrow/acero/aggregate_node.cc | 463 ++++++++++-------- cpp/src/arrow/acero/aggregate_node.h | 76 +++ cpp/src/arrow/acero/plan_test.cc | 2 +- cpp/src/arrow/engine/substrait/options.cc | 36 +- .../engine/substrait/relation_internal.cc | 67 +-- .../engine/substrait/relation_internal.h | 24 +- 6 files changed, 358 insertions(+), 310 deletions(-) create mode 100644 cpp/src/arrow/acero/aggregate_node.h diff --git a/cpp/src/arrow/acero/aggregate_node.cc b/cpp/src/arrow/acero/aggregate_node.cc index 6669d30bcc0..a3ee058ff5b 100644 --- a/cpp/src/arrow/acero/aggregate_node.cc +++ b/cpp/src/arrow/acero/aggregate_node.cc @@ -21,6 +21,7 @@ #include #include +#include "arrow/acero/aggregate_node.h" #include "arrow/acero/exec_plan.h" #include "arrow/acero/options.h" #include "arrow/acero/query_context.h" @@ -85,8 +86,43 @@ std::vector ExtendWithGroupIdType(const std::vector& in_ return aggr_in_types; } -Result GetKernel(ExecContext* ctx, const Aggregate& aggregate, +void DefaultAggregateOptions(Aggregate* aggregate_ptr, + const std::shared_ptr function) { + Aggregate& aggregate = *aggregate_ptr; + if (aggregate.options == nullptr) { + DCHECK(!function->doc().options_required); + const auto* default_options = function->default_options(); + if (default_options) { + aggregate.options = default_options->Copy(); + } + } +} + +using GetKernel = std::function(ExecContext*, Aggregate*, + const std::vector&)>; + +Result GetScalarAggregateKernel(ExecContext* ctx, Aggregate* aggregate_ptr, + const std::vector& in_types) { + Aggregate& aggregate = *aggregate_ptr; + ARROW_ASSIGN_OR_RAISE(auto function, + ctx->func_registry()->GetFunction(aggregate.function)); + if (function->kind() != Function::SCALAR_AGGREGATE) { + if (function->kind() == Function::HASH_AGGREGATE) { + return Status::Invalid("The provided function (", aggregate.function, + ") is a hash aggregate function. Since there are no " + "keys to group by, a scalar aggregate function was " + "expected (normally these do not start with hash_)"); + } + return Status::Invalid("The provided function(", aggregate.function, + ") is not an aggregate function"); + } + DefaultAggregateOptions(&aggregate, function); + return function->DispatchExact(in_types); +} + +Result GetHashAggregateKernel(ExecContext* ctx, Aggregate* aggregate_ptr, const std::vector& in_types) { + Aggregate& aggregate = *aggregate_ptr; const auto aggr_in_types = ExtendWithGroupIdType(in_types); ARROW_ASSIGN_OR_RAISE(auto function, ctx->func_registry()->GetFunction(aggregate.function)); @@ -100,16 +136,16 @@ Result GetKernel(ExecContext* ctx, const Aggregate& return Status::Invalid("The provided function(", aggregate.function, ") is not an aggregate function"); } - ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(aggr_in_types)); - return static_cast(kernel); + DefaultAggregateOptions(&aggregate, function); + return function->DispatchExact(aggr_in_types); } -Result> InitKernel(const HashAggregateKernel* kernel, - ExecContext* ctx, - const Aggregate& aggregate, - const std::vector& in_types) { - const auto aggr_in_types = ExtendWithGroupIdType(in_types); +using InitKernel = std::function>( + const Kernel*, ExecContext*, const Aggregate&, const std::vector&)>; +Result> InitScalarAggregateKernel( + const Kernel* kernel, ExecContext* ctx, const Aggregate& aggregate, + const std::vector& in_types) { KernelContext kernel_ctx{ctx}; const auto* options = arrow::internal::checked_cast(aggregate.options.get()); @@ -122,53 +158,91 @@ Result> InitKernel(const HashAggregateKernel* kerne } ARROW_ASSIGN_OR_RAISE( - auto state, - kernel->init(&kernel_ctx, KernelInitArgs{kernel, aggr_in_types, options})); + auto state, kernel->init(&kernel_ctx, KernelInitArgs{kernel, in_types, options})); return std::move(state); } -Result> GetKernels( - ExecContext* ctx, const std::vector& aggregates, +Result> InitHashAggregateKernel( + const Kernel* kernel, ExecContext* ctx, const Aggregate& aggregate, + const std::vector& in_types) { + const auto aggr_in_types = ExtendWithGroupIdType(in_types); + return InitScalarAggregateKernel(kernel, ctx, aggregate, std::move(aggr_in_types)); +} + +Result> GetKernels( + GetKernel get_kernel, ExecContext* ctx, std::vector* aggregates_ptr, const std::vector>& in_types) { + std::vector& aggregates = *aggregates_ptr; if (aggregates.size() != in_types.size()) { return Status::Invalid(aggregates.size(), " aggregate functions were specified but ", in_types.size(), " arguments were provided."); } - std::vector kernels(in_types.size()); + std::vector kernels(in_types.size()); for (size_t i = 0; i < aggregates.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(kernels[i], GetKernel(ctx, aggregates[i], in_types[i])); + ARROW_ASSIGN_OR_RAISE(kernels[i], get_kernel(ctx, &aggregates[i], in_types[i])); } return kernels; } -Result>> InitKernels( - const std::vector& kernels, ExecContext* ctx, +template +Result>>> InitKernels( + InitKernel init_kernel, const std::vector& kernels, + ExecContext* ctx, size_t num_states_per_kernel, const std::vector& aggregates, const std::vector>& in_types) { - std::vector> states(kernels.size()); + std::vector>> states(kernels.size()); for (size_t i = 0; i < aggregates.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(states[i], - InitKernel(kernels[i], ctx, aggregates[i], in_types[i])); + states[i].resize(num_states_per_kernel); + for (size_t j = 0; j < num_states_per_kernel; j++) { + ARROW_ASSIGN_OR_RAISE(states[i][j], + init_kernel(kernels[i], ctx, aggregates[i], in_types[i])); + } } return std::move(states); } -Result ResolveKernels( - const std::vector& aggregates, - const std::vector& kernels, - const std::vector>& states, ExecContext* ctx, - const std::vector>& types) { +Result> ResolveKernel(const Aggregate& aggregate, + const Kernel* kernel, + const std::unique_ptr& state, + ExecContext* ctx, + const std::vector& types) { + KernelContext kernel_ctx{ctx}; + kernel_ctx.SetState(state.get()); + + ARROW_ASSIGN_OR_RAISE(auto type, + kernel->signature->out_type().Resolve(&kernel_ctx, types)); + return field(aggregate.function, type.GetSharedPtr()); +} + +using ResolveKernels = std::function( + const std::vector&, const std::vector&, + const std::vector>>&, ExecContext*, + const std::vector>&)>; + +Result ResolveScalarAggregateKernels( + const std::vector& aggregates, const std::vector& kernels, + const std::vector>>& states, + ExecContext* ctx, const std::vector>& types) { FieldVector fields(types.size()); for (size_t i = 0; i < kernels.size(); ++i) { - KernelContext kernel_ctx{ctx}; - kernel_ctx.SetState(states[i].get()); + ARROW_ASSIGN_OR_RAISE( + fields[i], ResolveKernel(aggregates[i], kernels[i], states[i][0], ctx, types[i])); + } + return fields; +} + +Result ResolveHashAggregateKernels( + const std::vector& aggregates, const std::vector& kernels, + const std::vector>>& states, + ExecContext* ctx, const std::vector>& types) { + FieldVector fields(types.size()); + for (size_t i = 0; i < kernels.size(); ++i) { const auto aggr_in_types = ExtendWithGroupIdType(types[i]); - ARROW_ASSIGN_OR_RAISE( - auto type, kernels[i]->signature->out_type().Resolve(&kernel_ctx, aggr_in_types)); - fields[i] = field(aggregates[i].function, type.GetSharedPtr()); + ARROW_ASSIGN_OR_RAISE(fields[i], ResolveKernel(aggregates[i], kernels[i], + states[i][0], ctx, aggr_in_types)); } return fields; } @@ -294,89 +368,21 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { const auto& input_schema = *inputs[0]->output_schema(); auto exec_ctx = plan->query_context()->exec_context(); - std::vector segment_field_ids(segment_keys.size()); - std::vector segment_key_types(segment_keys.size()); - for (size_t i = 0; i < segment_keys.size(); i++) { - ARROW_ASSIGN_OR_RAISE(FieldPath match, segment_keys[i].FindOne(input_schema)); - if (match.indices().size() > 1) { - // ARROW-18369: Support nested references as segment ids - return Status::Invalid("Nested references cannot be used as segment ids"); - } - segment_field_ids[i] = match[0]; - segment_key_types[i] = input_schema.field(match[0])->type().get(); - } - - ARROW_ASSIGN_OR_RAISE(auto segmenter, - RowSegmenter::Make(std::move(segment_key_types), - /*nullable_keys=*/false, exec_ctx)); - - std::vector> kernel_intypes(aggregates.size()); - std::vector kernels(aggregates.size()); - std::vector>> states(kernels.size()); - FieldVector fields(kernels.size() + segment_keys.size()); - std::vector> target_fieldsets(kernels.size()); - - for (size_t i = 0; i < kernels.size(); ++i) { - const auto& target_fieldset = aggregate_options.aggregates[i].target; - for (const auto& target : target_fieldset) { - ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(target).FindOne(input_schema)); - target_fieldsets[i].push_back(match[0]); - } - - ARROW_ASSIGN_OR_RAISE( - auto function, exec_ctx->func_registry()->GetFunction(aggregates[i].function)); - - if (function->kind() != Function::SCALAR_AGGREGATE) { - if (function->kind() == Function::HASH_AGGREGATE) { - return Status::Invalid("The provided function (", aggregates[i].function, - ") is a hash aggregate function. Since there are no " - "keys to group by, a scalar aggregate function was " - "expected (normally these do not start with hash_)"); - } - return Status::Invalid("The provided function(", aggregates[i].function, - ") is not an aggregate function"); - } - - std::vector in_types; - for (const auto& target : target_fieldsets[i]) { - in_types.emplace_back(input_schema.field(target)->type().get()); - } - kernel_intypes[i] = in_types; - ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, - function->DispatchExact(kernel_intypes[i])); - kernels[i] = static_cast(kernel); - - if (aggregates[i].options == nullptr) { - DCHECK(!function->doc().options_required); - const auto* default_options = function->default_options(); - if (default_options) { - aggregates[i].options = default_options->Copy(); - } - } - - KernelContext kernel_ctx{exec_ctx}; - states[i].resize(plan->query_context()->max_concurrency()); - RETURN_NOT_OK(Kernel::InitAll( - &kernel_ctx, - KernelInitArgs{kernels[i], kernel_intypes[i], aggregates[i].options.get()}, - &states[i])); - - // pick one to resolve the kernel signature - kernel_ctx.SetState(states[i][0].get()); - ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve( - &kernel_ctx, kernel_intypes[i])); + ARROW_ASSIGN_OR_RAISE(auto args, + aggregate::MakeAggregateNodeArgs( + input_schema, keys, segment_keys, aggregates, + plan->query_context()->max_concurrency(), exec_ctx)); - fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr()); - } - for (size_t i = 0; i < segment_keys.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(fields[kernels.size() + i], - segment_keys[i].GetOne(*inputs[0]->output_schema())); + std::vector kernels; + kernels.reserve(args.kernels.size()); + for (auto kernel : args.kernels) { + kernels.push_back(static_cast(kernel)); } - return plan->EmplaceNode( - plan, std::move(inputs), schema(std::move(fields)), std::move(segmenter), - std::move(segment_field_ids), std::move(target_fieldsets), std::move(aggregates), - std::move(kernels), std::move(kernel_intypes), std::move(states)); + plan, std::move(inputs), std::move(args.output_schema), std::move(args.segmenter), + std::move(args.segment_key_field_ids), std::move(args.target_fieldsets), + std::move(args.aggregates), std::move(kernels), std::move(args.kernel_intypes), + std::move(args.states)); } const char* kind_name() const override { return "ScalarAggregateNode"; } @@ -560,114 +566,40 @@ class GroupByNode : public ExecNode, public TracedNode { const ExecNodeOptions& options) { RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "GroupByNode")); - auto input = inputs[0]; const auto& aggregate_options = checked_cast(options); + auto aggregates = aggregate_options.aggregates; const auto& keys = aggregate_options.keys; const auto& segment_keys = aggregate_options.segment_keys; - // Copy (need to modify options pointer below) - auto aggs = aggregate_options.aggregates; if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 && segment_keys.size() > 0) { return Status::NotImplemented("Segmented aggregation in a multi-threaded plan"); } - // Get input schema - auto input_schema = input->output_schema(); - - // Find input field indices for key fields - std::vector key_field_ids(keys.size()); - for (size_t i = 0; i < keys.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(*input_schema)); - key_field_ids[i] = match[0]; - } - - // Find input field indices for segment key fields - std::vector segment_key_field_ids(segment_keys.size()); - for (size_t i = 0; i < segment_keys.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto match, segment_keys[i].FindOne(*input_schema)); - segment_key_field_ids[i] = match[0]; - } - - // Check key fields and segment key fields are disjoint - std::unordered_set key_field_id_set(key_field_ids.begin(), key_field_ids.end()); - for (const auto& segment_key_field_id : segment_key_field_ids) { - if (key_field_id_set.find(segment_key_field_id) != key_field_id_set.end()) { - return Status::Invalid("Group-by aggregation with field '", - input_schema->field(segment_key_field_id)->name(), - "' as both key and segment key"); - } - } - - // Find input field indices for aggregates - std::vector> agg_src_fieldsets(aggs.size()); - for (size_t i = 0; i < aggs.size(); ++i) { - const auto& target_fieldset = aggs[i].target; - for (const auto& target : target_fieldset) { - ARROW_ASSIGN_OR_RAISE(auto match, target.FindOne(*input_schema)); - agg_src_fieldsets[i].push_back(match[0]); - } - } - - // Build vector of aggregate source field data types - std::vector> agg_src_types(aggs.size()); - for (size_t i = 0; i < aggs.size(); ++i) { - for (const auto& agg_src_field_id : agg_src_fieldsets[i]) { - agg_src_types[i].push_back(input_schema->field(agg_src_field_id)->type().get()); - } - } - - // Build vector of segment key field data types - std::vector segment_key_types(segment_keys.size()); - for (size_t i = 0; i < segment_keys.size(); ++i) { - auto segment_key_field_id = segment_key_field_ids[i]; - segment_key_types[i] = input_schema->field(segment_key_field_id)->type().get(); - } - - auto ctx = plan->query_context()->exec_context(); - - ARROW_ASSIGN_OR_RAISE(auto segmenter, - RowSegmenter::Make(std::move(segment_key_types), - /*nullable_keys=*/false, ctx)); - - // Construct aggregates - ARROW_ASSIGN_OR_RAISE(auto agg_kernels, GetKernels(ctx, aggs, agg_src_types)); - - ARROW_ASSIGN_OR_RAISE(auto agg_states, - InitKernels(agg_kernels, ctx, aggs, agg_src_types)); - - ARROW_ASSIGN_OR_RAISE( - FieldVector agg_result_fields, - ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_types)); + const auto& input_schema = *inputs[0]->output_schema(); + auto exec_ctx = plan->query_context()->exec_context(); - // Build field vector for output schema - FieldVector output_fields{keys.size() + segment_keys.size() + aggs.size()}; + ARROW_ASSIGN_OR_RAISE(auto args, + aggregate::MakeAggregateNodeArgs( + input_schema, keys, segment_keys, aggregates, + plan->query_context()->max_concurrency(), exec_ctx)); - // Aggregate fields come before key fields to match the behavior of GroupBy function - for (size_t i = 0; i < aggs.size(); ++i) { - output_fields[i] = - agg_result_fields[i]->WithName(aggregate_options.aggregates[i].name); + std::vector kernels; + kernels.reserve(args.kernels.size()); + for (auto kernel : args.kernels) { + kernels.push_back(static_cast(kernel)); } - size_t base = aggs.size(); - for (size_t i = 0; i < keys.size(); ++i) { - int key_field_id = key_field_ids[i]; - output_fields[base + i] = input_schema->field(key_field_id); - } - base += keys.size(); - for (size_t i = 0; i < segment_keys.size(); ++i) { - int segment_key_field_id = segment_key_field_ids[i]; - output_fields[base + i] = input_schema->field(segment_key_field_id); - } - - return input->plan()->EmplaceNode( - input, schema(std::move(output_fields)), std::move(key_field_ids), - std::move(segment_key_field_ids), std::move(segmenter), std::move(agg_src_types), - std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels)); + return inputs[0]->plan()->EmplaceNode( + inputs[0], std::move(args.output_schema), std::move(args.grouping_key_field_ids), + std::move(args.segment_key_field_ids), std::move(args.segmenter), + std::move(args.kernel_intypes), std::move(args.target_fieldsets), + std::move(args.aggregates), std::move(kernels)); } Status ResetKernelStates() { auto ctx = plan()->query_context()->exec_context(); - ARROW_RETURN_NOT_OK(InitKernels(agg_kernels_, ctx, aggs_, agg_src_types_)); + ARROW_RETURN_NOT_OK(InitKernels(InitHashAggregateKernel, agg_kernels_, ctx, + /*num_states_per_kernel=*/1, aggs_, agg_src_types_)); return Status::OK(); } @@ -703,7 +635,7 @@ class GroupByNode : public ExecNode, public TracedNode { {"function.kind", std::string(kind_name()) + "::Consume"}}); auto ctx = plan_->query_context()->exec_context(); KernelContext kernel_ctx{ctx}; - kernel_ctx.SetState(state->agg_states[i].get()); + kernel_ctx.SetState(state->agg_states[i][0].get()); std::vector column_values; for (const int field : agg_src_fieldsets_[i]) { @@ -745,13 +677,13 @@ class GroupByNode : public ExecNode, public TracedNode { auto ctx = plan_->query_context()->exec_context(); KernelContext batch_ctx{ctx}; - DCHECK(state0->agg_states[i]); - batch_ctx.SetState(state0->agg_states[i].get()); + DCHECK(state0->agg_states[i][0]); + batch_ctx.SetState(state0->agg_states[i][0].get()); RETURN_NOT_OK(agg_kernels_[i]->resize(&batch_ctx, state0->grouper->num_groups())); - RETURN_NOT_OK(agg_kernels_[i]->merge(&batch_ctx, std::move(*state->agg_states[i]), - *transposition.array())); - state->agg_states[i].reset(); + RETURN_NOT_OK(agg_kernels_[i]->merge( + &batch_ctx, std::move(*state->agg_states[i][0]), *transposition.array())); + state->agg_states[i][0].reset(); } } return Status::OK(); @@ -779,9 +711,9 @@ class GroupByNode : public ExecNode, public TracedNode { aggs_[i].options ? aggs_[i].options->ToString() : ""}, {"function.kind", std::string(kind_name()) + "::Finalize"}}); KernelContext batch_ctx{plan_->query_context()->exec_context()}; - batch_ctx.SetState(state->agg_states[i].get()); + batch_ctx.SetState(state->agg_states[i][0].get()); RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i])); - state->agg_states[i].reset(); + state->agg_states[i][0].reset(); } ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques()); @@ -893,7 +825,7 @@ class GroupByNode : public ExecNode, public TracedNode { private: struct ThreadLocalState { std::unique_ptr grouper; - std::vector> agg_states; + std::vector>> agg_states; }; ThreadLocalState* GetLocalState() { @@ -926,10 +858,10 @@ class GroupByNode : public ExecNode, public TracedNode { } } - ARROW_ASSIGN_OR_RAISE( - state->agg_states, - InitKernels(agg_kernels_, plan_->query_context()->exec_context(), aggs_, - agg_src_types)); + ARROW_ASSIGN_OR_RAISE(state->agg_states, + InitKernels(InitHashAggregateKernel, agg_kernels_, + plan_->query_context()->exec_context(), + /*num_states_per_kernel=*/1, aggs_, agg_src_types)); return Status::OK(); } @@ -968,6 +900,109 @@ class GroupByNode : public ExecNode, public TracedNode { } // namespace +namespace aggregate { + +Result MakeAggregateNodeArgs(const Schema& input_schema, + const std::vector& keys, + const std::vector& segment_keys, + const std::vector& aggs, + size_t num_states_per_kernel, + ExecContext* exec_ctx) { + std::vector aggregates(aggs); + GetKernel get_kernel = keys.empty() ? GetScalarAggregateKernel : GetHashAggregateKernel; + InitKernel init_kernel = + keys.empty() ? InitScalarAggregateKernel : InitHashAggregateKernel; + ResolveKernels resolve_kernels = + keys.empty() ? ResolveScalarAggregateKernels : ResolveHashAggregateKernels; + + // Find input field indices for key fields + std::vector key_field_ids(keys.size()); + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(input_schema)); + if (match.indices().size() > 1) { + // ARROW-18369: Support nested references as segment ids + return Status::Invalid("Nested references cannot be used as segment ids"); + } + key_field_ids[i] = match[0]; + } + + std::vector segment_field_ids(segment_keys.size()); + std::vector segment_key_types(segment_keys.size()); + for (size_t i = 0; i < segment_keys.size(); i++) { + ARROW_ASSIGN_OR_RAISE(FieldPath match, segment_keys[i].FindOne(input_schema)); + if (match.indices().size() > 1) { + // ARROW-18369: Support nested references as segment ids + return Status::Invalid("Nested references cannot be used as segment ids"); + } + segment_field_ids[i] = match[0]; + segment_key_types[i] = input_schema.field(match[0])->type().get(); + } + + ARROW_ASSIGN_OR_RAISE(auto segmenter, + RowSegmenter::Make(std::move(segment_key_types), + /*nullable_keys=*/false, exec_ctx)); + + std::vector> kernel_intypes(aggregates.size()); + FieldVector fields(aggregates.size() + keys.size() + segment_keys.size()); + std::vector> target_fieldsets(aggregates.size()); + + for (size_t i = 0; i < aggregates.size(); ++i) { + const auto& target_fieldset = aggregates[i].target; + for (const auto& target : target_fieldset) { + ARROW_ASSIGN_OR_RAISE(auto match, FieldRef(target).FindOne(input_schema)); + target_fieldsets[i].push_back(match[0]); + } + + std::vector in_types; + for (const auto& target : target_fieldsets[i]) { + in_types.emplace_back(input_schema.field(target)->type().get()); + } + kernel_intypes[i] = in_types; + } + + ARROW_ASSIGN_OR_RAISE(auto kernels, + GetKernels(get_kernel, exec_ctx, &aggregates, kernel_intypes)); + + ARROW_ASSIGN_OR_RAISE(auto states, + InitKernels(init_kernel, kernels, exec_ctx, num_states_per_kernel, + aggregates, kernel_intypes)); + ARROW_ASSIGN_OR_RAISE(auto resolved_fields, resolve_kernels(aggregates, kernels, states, + exec_ctx, kernel_intypes)); + + for (size_t i = 0; i < aggregates.size(); ++i) { + fields[i] = resolved_fields[i]->WithName(aggregates[i].name); + } + for (size_t i = 0; i < keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(fields[kernels.size() + i], keys[i].GetOne(input_schema)); + } + for (size_t i = 0; i < segment_keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(fields[kernels.size() + keys.size() + i], + segment_keys[i].GetOne(input_schema)); + } + + return AggregateNodeArgs{schema(std::move(fields)), + std::move(key_field_ids), + std::move(segment_field_ids), + std::move(segmenter), + std::move(target_fieldsets), + std::move(aggregates), + std::move(kernels), + std::move(kernel_intypes), + std::move(states)}; +} + +Result> MakeOutputSchema( + const Schema& input_schema, const std::vector& keys, + const std::vector& segment_keys, const std::vector& aggregates, + ExecContext* exec_ctx) { + ARROW_ASSIGN_OR_RAISE( + auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates, + /*num_states_per_kernel=*/0, exec_ctx)); + return std::move(args.output_schema); +} + +} // namespace aggregate + namespace internal { void RegisterAggregateNode(ExecFactoryRegistry* registry) { diff --git a/cpp/src/arrow/acero/aggregate_node.h b/cpp/src/arrow/acero/aggregate_node.h new file mode 100644 index 00000000000..100a0a2446c --- /dev/null +++ b/cpp/src/arrow/acero/aggregate_node.h @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/acero/visibility.h" +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/row/grouper.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace acero { +namespace aggregate { + +using compute::Aggregate; +using compute::default_exec_context; +using compute::ExecContext; +using compute::Kernel; +using compute::KernelState; +using compute::RowSegmenter; + +struct ARROW_ACERO_EXPORT AggregateNodeArgs { + std::shared_ptr output_schema; + std::vector grouping_key_field_ids; + std::vector segment_key_field_ids; + std::unique_ptr segmenter; + std::vector> target_fieldsets; + std::vector aggregates; + std::vector kernels; + std::vector> kernel_intypes; + std::vector>> states; +}; + +/// \brief Make the arguments of an aggregate node +/// +/// \param[in] input_schema the schema of the input to the node +/// \param[in] keys the grouping keys for the aggregation +/// \param[in] segment_keys the segmenting keys for the aggregation +/// \param[in] num_states_per_kernel number of states per kernel for the aggregation +/// \param[in] exec_ctx the execution context for the aggregation +ARROW_ACERO_EXPORT Result MakeAggregateNodeArgs( + const Schema& input_schema, const std::vector& keys, + const std::vector& segment_keys, const std::vector& aggregates, + size_t num_states_per_kernel = 1, ExecContext* exec_ctx = default_exec_context()); + +/// \brief Make the output schema of an aggregate node +/// +/// \param[in] input_schema the schema of the input to the node +/// \param[in] keys the grouping keys for the aggregation +/// \param[in] segment_keys the segmenting keys for the aggregation +/// \param[in] num_states_per_kernel number of states per kernel for the aggregation +/// \param[in] exec_ctx the execution context for the aggregation +ARROW_ACERO_EXPORT Result> MakeOutputSchema( + const Schema& input_schema, const std::vector& keys, + const std::vector& segment_keys, const std::vector& aggregates, + size_t num_states_per_kernel, ExecContext* exec_ctx = default_exec_context()); + +} // namespace aggregate +} // namespace acero +} // namespace arrow diff --git a/cpp/src/arrow/acero/plan_test.cc b/cpp/src/arrow/acero/plan_test.cc index a3ba1946a1a..1718b08d797 100644 --- a/cpp/src/arrow/acero/plan_test.cc +++ b/cpp/src/arrow/acero/plan_test.cc @@ -517,7 +517,7 @@ TEST(ExecPlan, ToString) { custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC], null_placement=AtEnd}} :FilterNode{filter=(sum(multiply(i32, 2)) > 10)} :GroupByNode{keys=["bool"], aggregates=[ - hash_sum(multiply(i32, 2)), + hash_sum(multiply(i32, 2), {skip_nulls=true, min_count=1}), hash_count(multiply(i32, 2), {mode=NON_NULL}), hash_count_all(*), ]} diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 979db875df2..905a65163ab 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/acero/aggregate_node.h" #include "arrow/acero/asof_join_node.h" #include "arrow/acero/options.h" #include "arrow/engine/substrait/expression_internal.h" @@ -187,48 +188,37 @@ class DefaultExtensionProvider : public BaseExtensionProvider { auto input_schema = inputs[0].output_schema; - // store key fields to be used when output schema is created - std::vector key_field_ids; std::vector keys; for (auto& ref : seg_agg_rel.grouping_keys()) { ARROW_ASSIGN_OR_RAISE(auto field_ref, DirectReferenceFromProto(&ref, ext_set, conv_opts)); - ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); - key_field_ids.emplace_back(std::move(match[0])); keys.emplace_back(std::move(field_ref)); } - // store segment key fields to be used when output schema is created - std::vector segment_key_field_ids; std::vector segment_keys; for (auto& ref : seg_agg_rel.segment_keys()) { ARROW_ASSIGN_OR_RAISE(auto field_ref, DirectReferenceFromProto(&ref, ext_set, conv_opts)); - ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); - segment_key_field_ids.emplace_back(std::move(match[0])); segment_keys.emplace_back(std::move(field_ref)); } std::vector aggregates; aggregates.reserve(seg_agg_rel.measures_size()); - std::vector> agg_src_fieldsets; - agg_src_fieldsets.reserve(seg_agg_rel.measures_size()); for (auto agg_measure : seg_agg_rel.measures()) { - ARROW_ASSIGN_OR_RAISE( - auto parsed_measure, - internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts, - /*is_hash=*/!keys.empty(), input_schema)); - aggregates.push_back(std::move(parsed_measure.aggregate)); - agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset)); + ARROW_ASSIGN_OR_RAISE(auto aggregate, internal::ParseAggregateMeasure( + agg_measure, ext_set, conv_opts, + /*is_hash=*/!keys.empty(), input_schema)); + aggregates.push_back(std::move(aggregate)); } - ARROW_ASSIGN_OR_RAISE(auto decl_info, - internal::MakeAggregateDeclaration( - std::move(inputs[0].declaration), std::move(input_schema), - seg_agg_rel.measures_size(), std::move(aggregates), - std::move(agg_src_fieldsets), std::move(keys), - std::move(key_field_ids), std::move(segment_keys), - std::move(segment_key_field_ids), ext_set, conv_opts)); + ARROW_ASSIGN_OR_RAISE(auto args, acero::aggregate::MakeAggregateNodeArgs( + *input_schema, keys, segment_keys, aggregates)); + + ARROW_ASSIGN_OR_RAISE( + auto decl_info, + internal::MakeAggregateDeclaration( + std::move(inputs[0].declaration), std::move(args.output_schema), + std::move(aggregates), std::move(keys), std::move(segment_keys))); const auto& output_schema = decl_info.output_schema; size_t out_size = output_schema->num_fields(); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index d1a81d3eaf6..4716e2b0614 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -28,6 +28,7 @@ #include #include +#include "arrow/acero/aggregate_node.h" #include "arrow/acero/exec_plan.h" #include "arrow/acero/options.h" #include "arrow/compute/api_aggregate.h" @@ -294,7 +295,7 @@ Status DiscoverFilesFromDir(const std::shared_ptr& local_fs namespace internal { -Result ParseAggregateMeasure( +Result ParseAggregateMeasure( const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, const ConversionOptions& conversion_options, bool is_hash, const std::shared_ptr input_schema) { @@ -314,50 +315,16 @@ Result ParseAggregateMeasure( ARROW_ASSIGN_OR_RAISE(converter, ext_set.registry()->GetSubstraitAggregateToArrow( aggregate_call.id())); } - ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); - - // find aggregate field ids from schema - const auto& target = arrow_agg.target; - std::vector fieldset; - fieldset.reserve(target.size()); - for (const auto& field_ref : target) { - ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); - fieldset.push_back(match[0]); - } - - return ParsedMeasure{std::move(arrow_agg), std::move(fieldset)}; + return converter(aggregate_call); } else { return Status::Invalid("substrait::AggregateFunction not provided"); } } ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( - acero::Declaration input_decl, std::shared_ptr input_schema, - const int measure_size, std::vector aggregates, - std::vector> agg_src_fieldsets, std::vector keys, - std::vector key_field_ids, std::vector segment_keys, - std::vector segment_key_field_ids, const ExtensionSet& ext_set, - const ConversionOptions& conversion_options) { - FieldVector output_fields; - output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() + - measure_size); - // extract aggregate fields to output schema - for (const auto& agg_src_fieldset : agg_src_fieldsets) { - for (int field : agg_src_fieldset) { - output_fields.emplace_back(input_schema->field(field)); - } - } - // extract key fields to output schema - for (int key_field_id : key_field_ids) { - output_fields.emplace_back(input_schema->field(key_field_id)); - } - // extract segment key fields to output schema - for (int segment_key_field_id : segment_key_field_ids) { - output_fields.emplace_back(input_schema->field(segment_key_field_id)); - } - - std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - + acero::Declaration input_decl, std::shared_ptr aggregate_schema, + std::vector aggregates, std::vector keys, + std::vector segment_keys) { return DeclarationInfo{ acero::Declaration::Sequence( {std::move(input_decl), @@ -776,22 +743,17 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& // prepare output schema from aggregates auto input_schema = input.output_schema; - // store key fields to be used when output schema is created - std::vector key_field_ids; std::vector keys; if (aggregate.groupings_size() > 0) { const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); int grouping_expr_size = group.grouping_expressions_size(); keys.reserve(grouping_expr_size); - key_field_ids.reserve(grouping_expr_size); for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) { ARROW_ASSIGN_OR_RAISE( compute::Expression expr, FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options)); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { - ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); - key_field_ids.emplace_back(std::move(match[0])); keys.emplace_back(std::move(*field_ref)); } else { return Status::Invalid( @@ -803,25 +765,24 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& const int measure_size = aggregate.measures_size(); std::vector aggregates; aggregates.reserve(measure_size); - // store aggregate fields to be used when output schema is created - std::vector> agg_src_fieldsets; - agg_src_fieldsets.reserve(measure_size); for (int measure_id = 0; measure_id < measure_size; measure_id++) { const auto& agg_measure = aggregate.measures(measure_id); ARROW_ASSIGN_OR_RAISE( - auto parsed_measure, + auto aggregate, internal::ParseAggregateMeasure(agg_measure, ext_set, conversion_options, /*is_hash=*/!keys.empty(), input_schema)); - aggregates.push_back(std::move(parsed_measure.aggregate)); - agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset)); + aggregates.push_back(std::move(aggregate)); } + ARROW_ASSIGN_OR_RAISE(auto args, + acero::aggregate::MakeAggregateNodeArgs( + *input_schema, keys, /*segment_keys=*/{}, aggregates)); + ARROW_ASSIGN_OR_RAISE( auto aggregate_declaration, internal::MakeAggregateDeclaration( - std::move(input.declaration), std::move(input_schema), measure_size, - std::move(aggregates), std::move(agg_src_fieldsets), std::move(keys), - std::move(key_field_ids), {}, {}, ext_set, conversion_options)); + std::move(input.declaration), std::move(args.output_schema), + std::move(aggregates), std::move(keys), /*segment_keys=*/{})); auto aggregate_schema = aggregate_declaration.output_schema; return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 72a0c3f98af..a436f1770d7 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -50,11 +50,6 @@ ARROW_ENGINE_EXPORT Result> ToProto( namespace internal { -struct ParsedMeasure { - compute::Aggregate aggregate; - std::vector fieldset; -}; - /// \brief Parse an aggregate relation's measure /// /// \param[in] agg_measure the measure @@ -63,7 +58,7 @@ struct ParsedMeasure { /// \param[in] input_schema the schema to which field refs apply /// \param[in] is_hash whether the measure is a hash one (i.e., aggregation keys exist) ARROW_ENGINE_EXPORT -Result ParseAggregateMeasure( +Result ParseAggregateMeasure( const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, const ConversionOptions& conversion_options, bool is_hash, const std::shared_ptr input_schema); @@ -71,23 +66,14 @@ Result ParseAggregateMeasure( /// \brief Make an aggregate declaration info /// /// \param[in] input_decl the input declaration to use -/// \param[in] input_schema the schema to which field refs apply -/// \param[in] measure_size the number of measures to use +/// \param[in] output_schema the schema to which field refs apply /// \param[in] aggregates the aggregates to use -/// \param[in] agg_src_fieldsets the field-sets per aggregate to use /// \param[in] keys the field-refs for grouping keys to use -/// \param[in] key_field_ids the field-ids for grouping keys to use /// \param[in] segment_keys the field-refs for segment keys to use -/// \param[in] segment_key_field_ids the field-ids for segment keys to use -/// \param[in] ext_set an extension mapping to use -/// \param[in] conversion_options options to control how the conversion is done ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( - acero::Declaration input_decl, std::shared_ptr input_schema, - const int measure_size, std::vector aggregates, - std::vector> agg_src_fieldsets, std::vector keys, - std::vector key_field_ids, std::vector segment_keys, - std::vector segment_key_field_ids, const ExtensionSet& ext_set, - const ConversionOptions& conversion_options); + acero::Declaration input_decl, std::shared_ptr output_schema, + std::vector aggregates, std::vector keys, + std::vector segment_keys); } // namespace internal From 85349f1e898bbb3e516c19837dac0f790840bca9 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 4 Apr 2023 15:54:19 -0400 Subject: [PATCH 2/2] clean up aggergate node API --- cpp/src/arrow/acero/aggregate_node.cc | 31 ++++++++++++++++- cpp/src/arrow/acero/aggregate_node.h | 34 +++---------------- cpp/src/arrow/engine/substrait/options.cc | 15 ++++---- .../engine/substrait/relation_internal.cc | 11 +++--- 4 files changed, 47 insertions(+), 44 deletions(-) diff --git a/cpp/src/arrow/acero/aggregate_node.cc b/cpp/src/arrow/acero/aggregate_node.cc index a3ee058ff5b..d1ba9eb9381 100644 --- a/cpp/src/arrow/acero/aggregate_node.cc +++ b/cpp/src/arrow/acero/aggregate_node.cc @@ -76,6 +76,35 @@ using compute::Segment; namespace acero { +namespace aggregate { + +struct ARROW_ACERO_EXPORT AggregateNodeArgs { + std::shared_ptr output_schema; + std::vector grouping_key_field_ids; + std::vector segment_key_field_ids; + std::unique_ptr segmenter; + std::vector> target_fieldsets; + std::vector aggregates; + std::vector kernels; + std::vector> kernel_intypes; + std::vector>> states; +}; + +/// \brief Make the arguments of an aggregate node +/// +/// \param[in] input_schema the schema of the input to the node +/// \param[in] keys the grouping keys for the aggregation +/// \param[in] segment_keys the segmenting keys for the aggregation +/// \param[in] aggregates the aggregates for the aggregation +/// \param[in] num_states_per_kernel number of states per kernel for the aggregation +/// \param[in] exec_ctx the execution context for the aggregation +ARROW_ACERO_EXPORT Result MakeAggregateNodeArgs( + const Schema& input_schema, const std::vector& keys, + const std::vector& segment_keys, const std::vector& aggregates, + size_t num_states_per_kernel = 1, ExecContext* exec_ctx = default_exec_context()); + +} // namespace aggregate + namespace { std::vector ExtendWithGroupIdType(const std::vector& in_types) { @@ -997,7 +1026,7 @@ Result> MakeOutputSchema( ExecContext* exec_ctx) { ARROW_ASSIGN_OR_RAISE( auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates, - /*num_states_per_kernel=*/0, exec_ctx)); + /*num_states_per_kernel=*/1, exec_ctx)); return std::move(args.output_schema); } diff --git a/cpp/src/arrow/acero/aggregate_node.h b/cpp/src/arrow/acero/aggregate_node.h index 100a0a2446c..342a429ddda 100644 --- a/cpp/src/arrow/acero/aggregate_node.h +++ b/cpp/src/arrow/acero/aggregate_node.h @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +// This API is EXPERIMENTAL. + +#pragma once + #include #include @@ -31,45 +35,17 @@ namespace aggregate { using compute::Aggregate; using compute::default_exec_context; using compute::ExecContext; -using compute::Kernel; -using compute::KernelState; -using compute::RowSegmenter; - -struct ARROW_ACERO_EXPORT AggregateNodeArgs { - std::shared_ptr output_schema; - std::vector grouping_key_field_ids; - std::vector segment_key_field_ids; - std::unique_ptr segmenter; - std::vector> target_fieldsets; - std::vector aggregates; - std::vector kernels; - std::vector> kernel_intypes; - std::vector>> states; -}; - -/// \brief Make the arguments of an aggregate node -/// -/// \param[in] input_schema the schema of the input to the node -/// \param[in] keys the grouping keys for the aggregation -/// \param[in] segment_keys the segmenting keys for the aggregation -/// \param[in] num_states_per_kernel number of states per kernel for the aggregation -/// \param[in] exec_ctx the execution context for the aggregation -ARROW_ACERO_EXPORT Result MakeAggregateNodeArgs( - const Schema& input_schema, const std::vector& keys, - const std::vector& segment_keys, const std::vector& aggregates, - size_t num_states_per_kernel = 1, ExecContext* exec_ctx = default_exec_context()); /// \brief Make the output schema of an aggregate node /// /// \param[in] input_schema the schema of the input to the node /// \param[in] keys the grouping keys for the aggregation /// \param[in] segment_keys the segmenting keys for the aggregation -/// \param[in] num_states_per_kernel number of states per kernel for the aggregation /// \param[in] exec_ctx the execution context for the aggregation ARROW_ACERO_EXPORT Result> MakeOutputSchema( const Schema& input_schema, const std::vector& keys, const std::vector& segment_keys, const std::vector& aggregates, - size_t num_states_per_kernel, ExecContext* exec_ctx = default_exec_context()); + ExecContext* exec_ctx = default_exec_context()); } // namespace aggregate } // namespace acero diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 905a65163ab..1d6a01774dc 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -211,16 +211,15 @@ class DefaultExtensionProvider : public BaseExtensionProvider { aggregates.push_back(std::move(aggregate)); } - ARROW_ASSIGN_OR_RAISE(auto args, acero::aggregate::MakeAggregateNodeArgs( - *input_schema, keys, segment_keys, aggregates)); + ARROW_ASSIGN_OR_RAISE(auto output_schema, + acero::aggregate::MakeOutputSchema(*input_schema, keys, + segment_keys, aggregates)); - ARROW_ASSIGN_OR_RAISE( - auto decl_info, - internal::MakeAggregateDeclaration( - std::move(inputs[0].declaration), std::move(args.output_schema), - std::move(aggregates), std::move(keys), std::move(segment_keys))); + ARROW_ASSIGN_OR_RAISE(auto decl_info, internal::MakeAggregateDeclaration( + std::move(inputs[0].declaration), + output_schema, std::move(aggregates), + std::move(keys), std::move(segment_keys))); - const auto& output_schema = decl_info.output_schema; size_t out_size = output_schema->num_fields(); std::vector field_output_indices(out_size); for (int i = 0; i < static_cast(out_size); i++) { diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 4716e2b0614..d4fbffad7b2 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -774,17 +774,16 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& aggregates.push_back(std::move(aggregate)); } - ARROW_ASSIGN_OR_RAISE(auto args, - acero::aggregate::MakeAggregateNodeArgs( + ARROW_ASSIGN_OR_RAISE(auto aggregate_schema, + acero::aggregate::MakeOutputSchema( *input_schema, keys, /*segment_keys=*/{}, aggregates)); ARROW_ASSIGN_OR_RAISE( auto aggregate_declaration, - internal::MakeAggregateDeclaration( - std::move(input.declaration), std::move(args.output_schema), - std::move(aggregates), std::move(keys), /*segment_keys=*/{})); + internal::MakeAggregateDeclaration(std::move(input.declaration), + aggregate_schema, std::move(aggregates), + std::move(keys), /*segment_keys=*/{})); - auto aggregate_schema = aggregate_declaration.output_schema; return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), std::move(aggregate_schema)); }