From 6cd2df60bd414669b960b59c4c8b229f06b4d700 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 27 Mar 2023 09:33:01 -0400 Subject: [PATCH 1/5] GH-34626: [C++] Add ordered/segmented aggregation Substrait extension --- cpp/proto/substrait/extension_rels.proto | 13 ++ cpp/src/arrow/compute/exec/source_node.cc | 15 ++ .../engine/substrait/expression_internal.cc | 133 +++++++------ .../engine/substrait/expression_internal.h | 6 + cpp/src/arrow/engine/substrait/options.cc | 67 +++++++ .../engine/substrait/relation_internal.cc | 174 +++++++++++------- .../engine/substrait/relation_internal.h | 49 +++++ cpp/src/arrow/engine/substrait/serde_test.cc | 125 +++++++++++++ 8 files changed, 459 insertions(+), 123 deletions(-) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index 78c11b7d7e2..25073903c60 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -58,3 +58,16 @@ message NamedTapRel { // If empty, field names will be automatically generated. repeated string columns = 3; } + +message SegmentedAggregateRel { + substrait.RelCommon common = 1; + + // Grouping keys of the aggregation + repeated substrait.Expression.ReferenceSegment grouping_keys = 2; + + // Segment keys of the aggregation + repeated substrait.Expression.ReferenceSegment segment_keys = 3; + + // A list of one or more aggregate expressions along with an optional filter. + repeated substrait.AggregateRel.Measure measures = 4; +} diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 310c0c64740..b6bdc7c5db5 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -28,8 +28,10 @@ #include "arrow/compute/expression.h" #include "arrow/datum.h" #include "arrow/io/util_internal.h" +#include "arrow/ipc/util.h" #include "arrow/result.h" #include "arrow/table.h" +#include "arrow/util/align_util.h" #include "arrow/util/async_generator.h" #include "arrow/util/async_util.h" #include "arrow/util/checked_cast.h" @@ -102,6 +104,19 @@ struct SourceNode : ExecNode, public TracedNode { batch_size = morsel_length; } ExecBatch batch = morsel.Slice(offset, batch_size); + for (auto& value : batch.values) { + if (value.is_array()) { + ARROW_ASSIGN_OR_RAISE( + value, util::EnsureAlignment(value.make_array(), ipc::kArrowAlignment, + default_memory_pool())); + } + if (value.is_chunked_array()) { + ARROW_ASSIGN_OR_RAISE( + value, + util::EnsureAlignment(value.chunked_array(), ipc::kArrowAlignment, + default_memory_pool())); + } + } if (has_ordering) { batch.index = batch_index; } diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index e1223a51329..43a44440bef 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -138,6 +138,77 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor* desc return value_desc->name(); } +Result FromProto(const substrait::Expression::ReferenceSegment* ref, + const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, + std::optional in_expr) { + auto in_ref = ref; + auto& out = in_expr; + while (ref != nullptr) { + switch (ref->reference_type_case()) { + case substrait::Expression::ReferenceSegment::kStructField: { + auto index = ref->struct_field().field(); + if (!out) { + // Root StructField (column selection) + out = compute::field_ref(FieldRef(index)); + } else if (auto out_ref = out->field_ref()) { + // Nested StructFields on the root (selection of struct-typed column + // combined with selecting struct fields) + out = compute::field_ref(FieldRef(*out_ref, index)); + } else if (out->call() && out->call()->function_name == "struct_field") { + // Nested StructFields on top of an arbitrary expression + auto* field_options = + checked_cast(out->call()->options.get()); + field_options->field_ref = FieldRef(std::move(field_options->field_ref), index); + } else { + // First StructField on top of an arbitrary expression + out = compute::call("struct_field", {std::move(*out)}, + arrow::compute::StructFieldOptions({index})); + } + + // Segment handled, continue with child segment (if any) + if (ref->struct_field().has_child()) { + ref = &ref->struct_field().child(); + } else { + ref = nullptr; + } + break; + } + case substrait::Expression::ReferenceSegment::kListElement: { + if (!out) { + // Root ListField (illegal) + return Status::Invalid( + "substrait::ListElement cannot take a Relation as an argument"); + } + + // ListField on top of an arbitrary expression + out = compute::call( + "list_element", + {std::move(*out), compute::literal(ref->list_element().offset())}); + + // Segment handled, continue with child segment (if any) + if (ref->list_element().has_child()) { + ref = &ref->list_element().child(); + } else { + ref = nullptr; + } + break; + } + default: + // Unimplemented construct, break out of loop + out.reset(); + ref = nullptr; + } + } + if (out) { + return *std::move(out); + } + + return Status::NotImplemented( + "conversion to arrow::compute::Expression from Substrait reference segment: ", + in_ref ? in_ref->DebugString() : "null"); +} + Result FromProto(const substrait::AggregateFunction& func, bool is_hash, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { @@ -193,67 +264,7 @@ Result FromProto(const substrait::Expression& expr, } const auto* ref = &expr.selection().direct_reference(); - while (ref != nullptr) { - switch (ref->reference_type_case()) { - case substrait::Expression::ReferenceSegment::kStructField: { - auto index = ref->struct_field().field(); - if (!out) { - // Root StructField (column selection) - out = compute::field_ref(FieldRef(index)); - } else if (auto out_ref = out->field_ref()) { - // Nested StructFields on the root (selection of struct-typed column - // combined with selecting struct fields) - out = compute::field_ref(FieldRef(*out_ref, index)); - } else if (out->call() && out->call()->function_name == "struct_field") { - // Nested StructFields on top of an arbitrary expression - auto* field_options = - checked_cast(out->call()->options.get()); - field_options->field_ref = - FieldRef(std::move(field_options->field_ref), index); - } else { - // First StructField on top of an arbitrary expression - out = compute::call("struct_field", {std::move(*out)}, - arrow::compute::StructFieldOptions({index})); - } - - // Segment handled, continue with child segment (if any) - if (ref->struct_field().has_child()) { - ref = &ref->struct_field().child(); - } else { - ref = nullptr; - } - break; - } - case substrait::Expression::ReferenceSegment::kListElement: { - if (!out) { - // Root ListField (illegal) - return Status::Invalid( - "substrait::ListElement cannot take a Relation as an argument"); - } - - // ListField on top of an arbitrary expression - out = compute::call( - "list_element", - {std::move(*out), compute::literal(ref->list_element().offset())}); - - // Segment handled, continue with child segment (if any) - if (ref->list_element().has_child()) { - ref = &ref->list_element().child(); - } else { - ref = nullptr; - } - break; - } - default: - // Unimplemented construct, break out of loop - out.reset(); - ref = nullptr; - } - } - if (out) { - return *std::move(out); - } - break; + return FromProto(ref, ext_set, conversion_options, std::move(out)); } case substrait::Expression::kIfThen: { diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index e947537dd1e..d522fdf7703 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include "arrow/compute/type_fwd.h" #include "arrow/datum.h" @@ -34,6 +35,11 @@ namespace engine { class SubstraitCall; +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::Expression::ReferenceSegment*, + const ExtensionSet&, const ConversionOptions&, + std::optional); + ARROW_ENGINE_EXPORT Result FromProto(const substrait::Expression&, const ExtensionSet&, const ConversionOptions&); diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 4a4f3a87150..90b745ace8e 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -75,6 +75,11 @@ class DefaultExtensionProvider : public BaseExtensionProvider { rel.UnpackTo(&named_tap_rel); return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set); } + if (rel.Is()) { + substrait_ext::SegmentedAggregateRel seg_agg_rel; + rel.UnpackTo(&seg_agg_rel); + return MakeSegmentedAggregateRel(conv_opts, inputs, seg_agg_rel, ext_set); + } return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ", rel.DebugString()); } @@ -165,6 +170,68 @@ class DefaultExtensionProvider : public BaseExtensionProvider { named_tap_rel.name(), renamed_schema)); return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt}; } + + Result MakeSegmentedAggregateRel( + const ConversionOptions& conv_opts, const std::vector& inputs, + const substrait_ext::SegmentedAggregateRel& seg_agg_rel, + const ExtensionSet& ext_set) { + if (inputs.size() != 1) { + return Status::Invalid( + "substrait_ext::SegmentedAggregateRel requires a single input but got: ", + inputs.size()); + } + + auto input_schema = inputs[0].output_schema; + + // store key fields to be used when output schema is created + std::vector key_field_ids; + std::vector keys; + for (auto& key_refseg : seg_agg_rel.grouping_keys()) { + ARROW_ASSIGN_OR_RAISE(auto expr, FromProto(&key_refseg, ext_set, conv_opts, {})); + if (auto field_ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); + key_field_ids.emplace_back(std::move(match[0])); + keys.emplace_back(std::move(*field_ref)); + } + } + + // store segment key fields to be used when output schema is created + std::vector segment_key_field_ids; + std::vector segment_keys; + for (auto& key_refseg : seg_agg_rel.segment_keys()) { + ARROW_ASSIGN_OR_RAISE(auto expr, FromProto(&key_refseg, ext_set, conv_opts, {})); + if (auto field_ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); + segment_key_field_ids.emplace_back(std::move(match[0])); + segment_keys.emplace_back(std::move(*field_ref)); + } + } + + std::vector aggregates; + std::vector> agg_src_fieldsets; + for (auto agg_measure : seg_agg_rel.measures()) { + ARROW_RETURN_NOT_OK(internal::ParseAggregateMeasure( + agg_measure, ext_set, conv_opts, /*is_hash=*/!keys.empty(), input_schema, + &aggregates, &agg_src_fieldsets)); + } + + ARROW_ASSIGN_OR_RAISE( + auto decl_info, + internal::MakeAggregateDeclaration( + seg_agg_rel.common(), std::move(inputs[0].declaration), + std::move(input_schema), seg_agg_rel.measures_size(), std::move(aggregates), + std::move(agg_src_fieldsets), std::move(keys), std::move(key_field_ids), + std::move(segment_keys), std::move(segment_key_field_ids), ext_set, + conv_opts)); + + const auto& output_schema = decl_info.output_schema; + size_t out_size = output_schema->num_fields(); + std::vector field_output_indices(out_size); + for (int i = 0; i < static_cast(out_size); i++) { + field_output_indices[i] = i; + } + return RelationInfo{decl_info, std::move(field_output_indices)}; + } }; namespace { diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index eefc37607ad..24e09ac1339 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -74,10 +74,9 @@ struct EmitInfo { std::shared_ptr schema; }; -template -Result GetEmitInfo(const RelMessage& rel, +Result GetEmitInfo(const substrait::RelCommon& rel_common, const std::shared_ptr& input_schema) { - const auto& emit = rel.common().emit(); + const auto& emit = rel_common.emit(); int emit_size = emit.output_mapping_size(); std::vector proj_field_refs(emit_size); EmitInfo emit_info; @@ -91,6 +90,11 @@ Result GetEmitInfo(const RelMessage& rel, emit_info.schema = schema(std::move(emit_fields)); return std::move(emit_info); } +template +Result GetEmitInfo(const RelMessage& rel, + const std::shared_ptr& input_schema) { + return GetEmitInfo(rel.common(), input_schema); +} Result ProcessEmitProject( std::optional rel_common_opt, @@ -128,16 +132,15 @@ Result ProcessEmitProject( } } -template -Result ProcessEmit(const RelMessage& rel, +Result ProcessEmit(std::optional rel_common_opt, const DeclarationInfo& no_emit_declr, const std::shared_ptr& schema) { - if (rel.has_common()) { - switch (rel.common().emit_kind_case()) { + if (rel_common_opt) { + switch (rel_common_opt->emit_kind_case()) { case substrait::RelCommon::EmitKindCase::kDirect: return no_emit_declr; case substrait::RelCommon::EmitKindCase::kEmit: { - ARROW_ASSIGN_OR_RAISE(auto emit_info, GetEmitInfo(rel, schema)); + ARROW_ASSIGN_OR_RAISE(auto emit_info, GetEmitInfo(*rel_common_opt, schema)); return DeclarationInfo{ compute::Declaration::Sequence( {no_emit_declr.declaration, @@ -152,6 +155,13 @@ Result ProcessEmit(const RelMessage& rel, return no_emit_declr; } } +template +Result ProcessEmit(const RelMessage& rel, + const DeclarationInfo& no_emit_declr, + const std::shared_ptr& schema) { + return ProcessEmit(rel.has_common() ? std::make_optional(rel.common()) : std::nullopt, + no_emit_declr, schema); +} /// In the specialization, a single ProjectNode is being used to /// get the Acero relation with or without emit. template <> @@ -293,6 +303,90 @@ Status DiscoverFilesFromDir(const std::shared_ptr& local_fs return Status::OK(); } +namespace internal { + +ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( + const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, bool is_hash, + const std::shared_ptr input_schema, + std::vector* aggregates_ptr, + std::vector>* agg_src_fieldsets_ptr) { + std::vector& aggregates = *aggregates_ptr; + std::vector>& agg_src_fieldsets = *agg_src_fieldsets_ptr; + if (agg_measure.has_measure()) { + if (agg_measure.has_filter()) { + return Status::NotImplemented("Aggregate filters are not supported."); + } + const auto& agg_func = agg_measure.measure(); + ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call, + FromProto(agg_func, is_hash, ext_set, conversion_options)); + ExtensionIdRegistry::SubstraitAggregateToArrow converter; + if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') { + ARROW_ASSIGN_OR_RAISE(converter, + ext_set.registry()->GetSubstraitAggregateToArrowFallback( + aggregate_call.id().name)); + } else { + ARROW_ASSIGN_OR_RAISE(converter, ext_set.registry()->GetSubstraitAggregateToArrow( + aggregate_call.id())); + } + ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); + + // find aggregate field ids from schema + const auto& target = arrow_agg.target; + size_t measure_id = agg_src_fieldsets.size(); + agg_src_fieldsets.push_back({}); + for (const auto& field_ref : target) { + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + agg_src_fieldsets[measure_id].push_back(match[0]); + } + + aggregates.push_back(std::move(arrow_agg)); + return Status::OK(); + } else { + return Status::Invalid("substrait::AggregateFunction not provided"); + } +} + +ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( + std::optional agg_common_opt, compute::Declaration input_decl, + std::shared_ptr input_schema, const int measure_size, + std::vector aggregates, + std::vector> agg_src_fieldsets, std::vector keys, + std::vector key_field_ids, std::vector segment_keys, + std::vector segment_key_field_ids, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { + FieldVector output_fields; + output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() + + measure_size); + // extract aggregate fields to output schema + for (const auto& agg_src_fieldset : agg_src_fieldsets) { + for (int field : agg_src_fieldset) { + output_fields.emplace_back(input_schema->field(field)); + } + } + // extract key fields to output schema + for (int key_field_id : key_field_ids) { + output_fields.emplace_back(input_schema->field(key_field_id)); + } + // extract segment key fields to output schema + for (int segment_key_field_id : segment_key_field_ids) { + output_fields.emplace_back(input_schema->field(segment_key_field_id)); + } + + std::shared_ptr aggregate_schema = schema(std::move(output_fields)); + + DeclarationInfo aggregate_declaration{ + compute::Declaration::Sequence( + {std::move(input_decl), + {"aggregate", compute::AggregateNodeOptions{aggregates, keys, segment_keys}}}), + aggregate_schema}; + + return ProcessEmit(std::move(agg_common_opt), std::move(aggregate_declaration), + std::move(aggregate_schema)); +} + +} // namespace internal + Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { static bool dataset_init = false; @@ -730,64 +824,20 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::vector aggregates; aggregates.reserve(measure_size); // store aggregate fields to be used when output schema is created - std::vector> agg_src_fieldsets(measure_size); + std::vector> agg_src_fieldsets; + agg_src_fieldsets.reserve(measure_size); for (int measure_id = 0; measure_id < measure_size; measure_id++) { const auto& agg_measure = aggregate.measures(measure_id); - if (agg_measure.has_measure()) { - if (agg_measure.has_filter()) { - return Status::NotImplemented("Aggregate filters are not supported."); - } - const auto& agg_func = agg_measure.measure(); - ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call, - FromProto(agg_func, /*is_hash=*/!keys.empty(), ext_set, - conversion_options)); - ExtensionIdRegistry::SubstraitAggregateToArrow converter; - if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') { - ARROW_ASSIGN_OR_RAISE( - converter, ext_set.registry()->GetSubstraitAggregateToArrowFallback( - aggregate_call.id().name)); - } else { - ARROW_ASSIGN_OR_RAISE( - converter, - ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id())); - } - ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); - - // find aggregate field ids from schema - const auto& target = arrow_agg.target; - for (const auto& field_ref : target) { - ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); - agg_src_fieldsets[measure_id].push_back(match[0]); - } - - aggregates.push_back(std::move(arrow_agg)); - } else { - return Status::Invalid("substrait::AggregateFunction not provided"); - } - } - FieldVector output_fields; - output_fields.reserve(key_field_ids.size() + measure_size); - // extract aggregate fields to output schema - for (const auto& agg_src_fieldset : agg_src_fieldsets) { - for (int field : agg_src_fieldset) { - output_fields.emplace_back(input_schema->field(field)); - } + ARROW_RETURN_NOT_OK(internal::ParseAggregateMeasure( + agg_measure, ext_set, conversion_options, /*is_hash=*/!keys.empty(), + input_schema, &aggregates, &agg_src_fieldsets)); } - // extract key fields to output schema - for (int key_field_id : key_field_ids) { - output_fields.emplace_back(input_schema->field(key_field_id)); - } - - std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - - DeclarationInfo aggregate_declaration{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - aggregate_schema}; - return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), - std::move(aggregate_schema)); + return internal::MakeAggregateDeclaration( + aggregate.has_common() ? std::make_optional(aggregate.common()) : std::nullopt, + std::move(input.declaration), std::move(input_schema), measure_size, + std::move(aggregates), std::move(agg_src_fieldsets), std::move(keys), + std::move(key_field_ids), {}, {}, ext_set, conversion_options); } case substrait::Rel::RelTypeCase::kExtensionLeaf: diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 17153f5365f..9718df377fc 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -30,6 +30,12 @@ #include "substrait/algebra.pb.h" // IWYU pragma: export namespace arrow { +namespace compute { + +struct Aggregate; +class AggregateNodeOptions; + +} // namespace compute namespace engine { /// \brief Convert a Substrait Rel object to an Acero declaration @@ -46,5 +52,48 @@ Result FromProto(const substrait::Rel&, const ExtensionSet&, ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); +namespace internal { + +/// \brief Parse an aggregate relation's measure +/// +/// \param[in] agg_measure the measure +/// \param[in] ext_set an extension mapping to use in parsing +/// \param[in] conversion_options options to control how the conversion is done +/// \param[in] input_schema the schema to which field refs apply +/// \param[in] is_hash whether the measure is a hash one (i.e., aggregation keys exist) +/// \param[out] aggregates points to vector to push the parsed measure into +/// \param[out] agg_src_fieldsets points to vector to push the parsed field set into +ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( + const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, bool is_hash, + const std::shared_ptr input_schema, + std::vector* aggregates, + std::vector>* agg_src_fieldsets); + +/// \brief Make an aggregate declaration info +/// +/// \param[in] agg_common_opt the aggregate relation's common info, if exists +/// \param[in] input_decl the input declaration to use +/// \param[in] input_schema the schema to which field refs apply +/// \param[in] measure_size the number of measures to use +/// \param[in] aggregates the aggregates to use +/// \param[in] agg_src_fieldsets the field-sets per aggregate to use +/// \param[in] keys the field-refs for grouping keys to use +/// \param[in] key_field_ids the field-ids for grouping keys to use +/// \param[in] segment_keys the field-refs for segment keys to use +/// \param[in] segment_key_field_ids the field-ids for segment keys to use +/// \param[in] ext_set an extension mapping to use +/// \param[in] conversion_options options to control how the conversion is done +ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( + std::optional agg_common_opt, compute::Declaration input_decl, + std::shared_ptr input_schema, const int measure_size, + std::vector aggregates, + std::vector> agg_src_fieldsets, std::vector keys, + std::vector key_field_ids, std::vector segment_keys, + std::vector segment_key_field_ids, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options); + +} // namespace internal + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 32b3b5c5c17..db7cd230917 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5607,5 +5607,130 @@ TEST(Substrait, PlanWithNamedTapExtension) { CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); } +TEST(Substrait, PlanWithSegmentedAggregateExtension) { + // This demos an extension relation + std::string substrait_json = R"({ + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "sum" + } + }], + "relations": [{ + "root": { + "input": { + "extension_single": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T"] + } + } + }, + "detail": { + "@type": "/arrow.substrait_ext.SegmentedAggregateRel", + "common": { + "direct": { + } + }, + "grouping_keys": [{ + "structField": { + "field": 1 + } + }], + "segment_keys": [{ + "structField": { + "field": 0 + } + }], + "measures": [{ + "measure": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + } + } + } + }], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + } + } + }] + } + } + }, + "names": ["v", "k", "t"] + } + }], + "expectedTypeUrls": [] + })"; + + std::shared_ptr input_schema = + schema({field("time", int32()), field("key", int32()), field("value", float64())}); + NamedTableProvider table_provider = AlwaysProvideSameTable( + TableFromJSON(input_schema, {"[[1, 1, 1], [1, 2, 2], [1, 1, 3]," + " [2, 2, 4], [2, 1, 5], [2, 2, 6]]"})); + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + conversion_options.named_tap_provider = + [](const std::string& tap_kind, std::vector inputs, + const std::string& tap_name, + std::shared_ptr tap_schema) -> Result { + return compute::Declaration{tap_kind, std::move(inputs), compute::ExecNodeOptions{}}; + }; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + std::shared_ptr output_schema = + schema({field("v", float64()), field("k", int32()), field("t", int32())}); + auto expected_table = + TableFromJSON(output_schema, {"[[4, 1, 1], [2, 2, 1], [10, 2, 2], [5, 1, 2]]"}); + CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); +} + } // namespace engine } // namespace arrow From fe8c4ddf0fb32ee0290161ba0f29e4cf7d659100 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 28 Mar 2023 09:00:36 -0400 Subject: [PATCH 2/5] requested fixes --- cpp/proto/substrait/extension_rels.proto | 2 - cpp/src/arrow/compute/exec/source_node.cc | 6 --- .../engine/substrait/expression_internal.cc | 47 +++++++++++------ .../engine/substrait/expression_internal.h | 4 ++ cpp/src/arrow/engine/substrait/options.cc | 37 +++++++------- .../engine/substrait/relation_internal.cc | 50 ++++++++----------- .../engine/substrait/relation_internal.h | 6 +-- cpp/src/arrow/engine/substrait/serde_test.cc | 4 -- 8 files changed, 74 insertions(+), 82 deletions(-) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index 25073903c60..06dd86330b3 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -60,8 +60,6 @@ message NamedTapRel { } message SegmentedAggregateRel { - substrait.RelCommon common = 1; - // Grouping keys of the aggregation repeated substrait.Expression.ReferenceSegment grouping_keys = 2; diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index b6bdc7c5db5..73eea80196f 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -110,12 +110,6 @@ struct SourceNode : ExecNode, public TracedNode { value, util::EnsureAlignment(value.make_array(), ipc::kArrowAlignment, default_memory_pool())); } - if (value.is_chunked_array()) { - ARROW_ASSIGN_OR_RAISE( - value, - util::EnsureAlignment(value.chunked_array(), ipc::kArrowAlignment, - default_memory_pool())); - } } if (has_ordering) { batch.index = batch_index; diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 43a44440bef..9a5bca41128 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -143,27 +143,27 @@ Result FromProto(const substrait::Expression::ReferenceSegm const ConversionOptions& conversion_options, std::optional in_expr) { auto in_ref = ref; - auto& out = in_expr; + auto& current = in_expr; while (ref != nullptr) { switch (ref->reference_type_case()) { case substrait::Expression::ReferenceSegment::kStructField: { auto index = ref->struct_field().field(); - if (!out) { + if (!current) { // Root StructField (column selection) - out = compute::field_ref(FieldRef(index)); - } else if (auto out_ref = out->field_ref()) { + current = compute::field_ref(FieldRef(index)); + } else if (auto current_ref = current->field_ref()) { // Nested StructFields on the root (selection of struct-typed column // combined with selecting struct fields) - out = compute::field_ref(FieldRef(*out_ref, index)); - } else if (out->call() && out->call()->function_name == "struct_field") { + current = compute::field_ref(FieldRef(*current_ref, index)); + } else if (current->call() && current->call()->function_name == "struct_field") { // Nested StructFields on top of an arbitrary expression auto* field_options = - checked_cast(out->call()->options.get()); + checked_cast(current->call()->options.get()); field_options->field_ref = FieldRef(std::move(field_options->field_ref), index); } else { // First StructField on top of an arbitrary expression - out = compute::call("struct_field", {std::move(*out)}, - arrow::compute::StructFieldOptions({index})); + current = compute::call("struct_field", {std::move(*current)}, + arrow::compute::StructFieldOptions({index})); } // Segment handled, continue with child segment (if any) @@ -175,16 +175,16 @@ Result FromProto(const substrait::Expression::ReferenceSegm break; } case substrait::Expression::ReferenceSegment::kListElement: { - if (!out) { + if (!current) { // Root ListField (illegal) return Status::Invalid( "substrait::ListElement cannot take a Relation as an argument"); } // ListField on top of an arbitrary expression - out = compute::call( + current = compute::call( "list_element", - {std::move(*out), compute::literal(ref->list_element().offset())}); + {std::move(*current), compute::literal(ref->list_element().offset())}); // Segment handled, continue with child segment (if any) if (ref->list_element().has_child()) { @@ -195,13 +195,13 @@ Result FromProto(const substrait::Expression::ReferenceSegm break; } default: - // Unimplemented construct, break out of loop - out.reset(); + // Unimplemented construct, break current of loop + current.reset(); ref = nullptr; } } - if (out) { - return *std::move(out); + if (current) { + return *std::move(current); } return Status::NotImplemented( @@ -209,6 +209,21 @@ Result FromProto(const substrait::Expression::ReferenceSegm in_ref ? in_ref->DebugString() : "null"); } +Result DirectReferenceFromProto( + const substrait::Expression::ReferenceSegment* refseg, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(compute::Expression expr, + FromProto(refseg, ext_set, conversion_options, {})); + const FieldRef* field_ref = expr.field_ref(); + if (field_ref) { + return *field_ref; + } else { + return Status::Invalid( + "A direct reference was expected but a more complex expression was given " + "instead"); + } +} + Result FromProto(const substrait::AggregateFunction& func, bool is_hash, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index d522fdf7703..9239ee5d235 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -35,6 +35,10 @@ namespace engine { class SubstraitCall; +ARROW_ENGINE_EXPORT +Result DirectReferenceFromProto(const substrait::Expression::ReferenceSegment*, + const ExtensionSet&, const ConversionOptions&); + ARROW_ENGINE_EXPORT Result FromProto(const substrait::Expression::ReferenceSegment*, const ExtensionSet&, const ConversionOptions&, diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 90b745ace8e..95570e4051f 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -187,24 +187,22 @@ class DefaultExtensionProvider : public BaseExtensionProvider { std::vector key_field_ids; std::vector keys; for (auto& key_refseg : seg_agg_rel.grouping_keys()) { - ARROW_ASSIGN_OR_RAISE(auto expr, FromProto(&key_refseg, ext_set, conv_opts, {})); - if (auto field_ref = expr.field_ref()) { - ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); - key_field_ids.emplace_back(std::move(match[0])); - keys.emplace_back(std::move(*field_ref)); - } + ARROW_ASSIGN_OR_RAISE(auto field_ref, + DirectReferenceFromProto(&key_refseg, ext_set, conv_opts)); + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + key_field_ids.emplace_back(std::move(match[0])); + keys.emplace_back(std::move(field_ref)); } // store segment key fields to be used when output schema is created std::vector segment_key_field_ids; std::vector segment_keys; for (auto& key_refseg : seg_agg_rel.segment_keys()) { - ARROW_ASSIGN_OR_RAISE(auto expr, FromProto(&key_refseg, ext_set, conv_opts, {})); - if (auto field_ref = expr.field_ref()) { - ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); - segment_key_field_ids.emplace_back(std::move(match[0])); - segment_keys.emplace_back(std::move(*field_ref)); - } + ARROW_ASSIGN_OR_RAISE(auto field_ref, + DirectReferenceFromProto(&key_refseg, ext_set, conv_opts)); + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + segment_key_field_ids.emplace_back(std::move(match[0])); + segment_keys.emplace_back(std::move(field_ref)); } std::vector aggregates; @@ -215,14 +213,13 @@ class DefaultExtensionProvider : public BaseExtensionProvider { &aggregates, &agg_src_fieldsets)); } - ARROW_ASSIGN_OR_RAISE( - auto decl_info, - internal::MakeAggregateDeclaration( - seg_agg_rel.common(), std::move(inputs[0].declaration), - std::move(input_schema), seg_agg_rel.measures_size(), std::move(aggregates), - std::move(agg_src_fieldsets), std::move(keys), std::move(key_field_ids), - std::move(segment_keys), std::move(segment_key_field_ids), ext_set, - conv_opts)); + ARROW_ASSIGN_OR_RAISE(auto decl_info, + internal::MakeAggregateDeclaration( + std::move(inputs[0].declaration), std::move(input_schema), + seg_agg_rel.measures_size(), std::move(aggregates), + std::move(agg_src_fieldsets), std::move(keys), + std::move(key_field_ids), std::move(segment_keys), + std::move(segment_key_field_ids), ext_set, conv_opts)); const auto& output_schema = decl_info.output_schema; size_t out_size = output_schema->num_fields(); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 24e09ac1339..56915da1580 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -74,9 +74,10 @@ struct EmitInfo { std::shared_ptr schema; }; -Result GetEmitInfo(const substrait::RelCommon& rel_common, +template +Result GetEmitInfo(const RelMessage& rel, const std::shared_ptr& input_schema) { - const auto& emit = rel_common.emit(); + const auto& emit = rel.common().emit(); int emit_size = emit.output_mapping_size(); std::vector proj_field_refs(emit_size); EmitInfo emit_info; @@ -90,11 +91,6 @@ Result GetEmitInfo(const substrait::RelCommon& rel_common, emit_info.schema = schema(std::move(emit_fields)); return std::move(emit_info); } -template -Result GetEmitInfo(const RelMessage& rel, - const std::shared_ptr& input_schema) { - return GetEmitInfo(rel.common(), input_schema); -} Result ProcessEmitProject( std::optional rel_common_opt, @@ -132,15 +128,16 @@ Result ProcessEmitProject( } } -Result ProcessEmit(std::optional rel_common_opt, +template +Result ProcessEmit(const RelMessage& rel, const DeclarationInfo& no_emit_declr, const std::shared_ptr& schema) { - if (rel_common_opt) { - switch (rel_common_opt->emit_kind_case()) { + if (rel.has_common()) { + switch (rel.common().emit_kind_case()) { case substrait::RelCommon::EmitKindCase::kDirect: return no_emit_declr; case substrait::RelCommon::EmitKindCase::kEmit: { - ARROW_ASSIGN_OR_RAISE(auto emit_info, GetEmitInfo(*rel_common_opt, schema)); + ARROW_ASSIGN_OR_RAISE(auto emit_info, GetEmitInfo(rel, schema)); return DeclarationInfo{ compute::Declaration::Sequence( {no_emit_declr.declaration, @@ -155,13 +152,6 @@ Result ProcessEmit(std::optional rel_comm return no_emit_declr; } } -template -Result ProcessEmit(const RelMessage& rel, - const DeclarationInfo& no_emit_declr, - const std::shared_ptr& schema) { - return ProcessEmit(rel.has_common() ? std::make_optional(rel.common()) : std::nullopt, - no_emit_declr, schema); -} /// In the specialization, a single ProjectNode is being used to /// get the Acero relation with or without emit. template <> @@ -348,9 +338,8 @@ ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( } ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( - std::optional agg_common_opt, compute::Declaration input_decl, - std::shared_ptr input_schema, const int measure_size, - std::vector aggregates, + compute::Declaration input_decl, std::shared_ptr input_schema, + const int measure_size, std::vector aggregates, std::vector> agg_src_fieldsets, std::vector keys, std::vector key_field_ids, std::vector segment_keys, std::vector segment_key_field_ids, const ExtensionSet& ext_set, @@ -375,14 +364,11 @@ ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - DeclarationInfo aggregate_declaration{ + return DeclarationInfo{ compute::Declaration::Sequence( {std::move(input_decl), {"aggregate", compute::AggregateNodeOptions{aggregates, keys, segment_keys}}}), aggregate_schema}; - - return ProcessEmit(std::move(agg_common_opt), std::move(aggregate_declaration), - std::move(aggregate_schema)); } } // namespace internal @@ -833,11 +819,15 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& input_schema, &aggregates, &agg_src_fieldsets)); } - return internal::MakeAggregateDeclaration( - aggregate.has_common() ? std::make_optional(aggregate.common()) : std::nullopt, - std::move(input.declaration), std::move(input_schema), measure_size, - std::move(aggregates), std::move(agg_src_fieldsets), std::move(keys), - std::move(key_field_ids), {}, {}, ext_set, conversion_options); + ARROW_ASSIGN_OR_RAISE( + auto aggregate_declaration, + internal::MakeAggregateDeclaration( + std::move(input.declaration), std::move(input_schema), measure_size, + std::move(aggregates), std::move(agg_src_fieldsets), std::move(keys), + std::move(key_field_ids), {}, {}, ext_set, conversion_options)); + + return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), + aggregate_declaration.output_schema); } case substrait::Rel::RelTypeCase::kExtensionLeaf: diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 9718df377fc..1ef909ded89 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -72,7 +72,6 @@ ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( /// \brief Make an aggregate declaration info /// -/// \param[in] agg_common_opt the aggregate relation's common info, if exists /// \param[in] input_decl the input declaration to use /// \param[in] input_schema the schema to which field refs apply /// \param[in] measure_size the number of measures to use @@ -85,9 +84,8 @@ ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( /// \param[in] ext_set an extension mapping to use /// \param[in] conversion_options options to control how the conversion is done ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( - std::optional agg_common_opt, compute::Declaration input_decl, - std::shared_ptr input_schema, const int measure_size, - std::vector aggregates, + compute::Declaration input_decl, std::shared_ptr input_schema, + const int measure_size, std::vector aggregates, std::vector> agg_src_fieldsets, std::vector keys, std::vector key_field_ids, std::vector segment_keys, std::vector segment_key_field_ids, const ExtensionSet& ext_set, diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index db7cd230917..fe3c5d29018 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5665,10 +5665,6 @@ TEST(Substrait, PlanWithSegmentedAggregateExtension) { }, "detail": { "@type": "/arrow.substrait_ext.SegmentedAggregateRel", - "common": { - "direct": { - } - }, "grouping_keys": [{ "structField": { "field": 1 From 6415c2a77e568e13703167c729807dc382a389b2 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 28 Mar 2023 15:10:14 -0400 Subject: [PATCH 3/5] renumber message fields --- cpp/proto/substrait/extension_rels.proto | 6 +++--- cpp/src/arrow/engine/substrait/options.cc | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index 06dd86330b3..f6d6ed4717c 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -61,11 +61,11 @@ message NamedTapRel { message SegmentedAggregateRel { // Grouping keys of the aggregation - repeated substrait.Expression.ReferenceSegment grouping_keys = 2; + repeated substrait.Expression.ReferenceSegment grouping_keys = 1; // Segment keys of the aggregation - repeated substrait.Expression.ReferenceSegment segment_keys = 3; + repeated substrait.Expression.ReferenceSegment segment_keys = 2; // A list of one or more aggregate expressions along with an optional filter. - repeated substrait.AggregateRel.Measure measures = 4; + repeated substrait.AggregateRel.Measure measures = 3; } diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 95570e4051f..f3d013d9a35 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -180,6 +180,10 @@ class DefaultExtensionProvider : public BaseExtensionProvider { "substrait_ext::SegmentedAggregateRel requires a single input but got: ", inputs.size()); } + if (seg_agg_rel.segment_keys_size() == 0) { + return Status::Invalid( + "substrait_ext::SegmentedAggregateRel requires at least one segment key"); + } auto input_schema = inputs[0].output_schema; From 72e0b2e2a1bb3f7a8f1f76ed076ac93d9f228538 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 29 Mar 2023 06:47:33 -0400 Subject: [PATCH 4/5] ParsedMeasure --- cpp/src/arrow/engine/substrait/options.cc | 11 +++++-- .../engine/substrait/relation_internal.cc | 29 +++++++++---------- .../engine/substrait/relation_internal.h | 21 ++++++-------- 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index f3d013d9a35..f0440bf4561 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -210,11 +210,16 @@ class DefaultExtensionProvider : public BaseExtensionProvider { } std::vector aggregates; + aggregates.reserve(seg_agg_rel.measures_size()); std::vector> agg_src_fieldsets; + agg_src_fieldsets.reserve(seg_agg_rel.measures_size()); for (auto agg_measure : seg_agg_rel.measures()) { - ARROW_RETURN_NOT_OK(internal::ParseAggregateMeasure( - agg_measure, ext_set, conv_opts, /*is_hash=*/!keys.empty(), input_schema, - &aggregates, &agg_src_fieldsets)); + ARROW_ASSIGN_OR_RAISE( + auto parsed_measure, + internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts, + /*is_hash=*/!keys.empty(), input_schema)); + aggregates.push_back(std::move(parsed_measure.aggregate)); + agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset)); } ARROW_ASSIGN_OR_RAISE(auto decl_info, diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 56915da1580..7f9c4c289fb 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -295,14 +295,10 @@ Status DiscoverFilesFromDir(const std::shared_ptr& local_fs namespace internal { -ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( +Result ParseAggregateMeasure( const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, const ConversionOptions& conversion_options, bool is_hash, - const std::shared_ptr input_schema, - std::vector* aggregates_ptr, - std::vector>* agg_src_fieldsets_ptr) { - std::vector& aggregates = *aggregates_ptr; - std::vector>& agg_src_fieldsets = *agg_src_fieldsets_ptr; + const std::shared_ptr input_schema) { if (agg_measure.has_measure()) { if (agg_measure.has_filter()) { return Status::NotImplemented("Aggregate filters are not supported."); @@ -323,15 +319,14 @@ ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( // find aggregate field ids from schema const auto& target = arrow_agg.target; - size_t measure_id = agg_src_fieldsets.size(); - agg_src_fieldsets.push_back({}); + std::vector fieldset; + fieldset.reserve(target.size()); for (const auto& field_ref : target) { ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); - agg_src_fieldsets[measure_id].push_back(match[0]); + fieldset.push_back(match[0]); } - aggregates.push_back(std::move(arrow_agg)); - return Status::OK(); + return ParsedMeasure{std::move(arrow_agg), std::move(fieldset)}; } else { return Status::Invalid("substrait::AggregateFunction not provided"); } @@ -814,9 +809,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& agg_src_fieldsets.reserve(measure_size); for (int measure_id = 0; measure_id < measure_size; measure_id++) { const auto& agg_measure = aggregate.measures(measure_id); - ARROW_RETURN_NOT_OK(internal::ParseAggregateMeasure( - agg_measure, ext_set, conversion_options, /*is_hash=*/!keys.empty(), - input_schema, &aggregates, &agg_src_fieldsets)); + ARROW_ASSIGN_OR_RAISE( + auto parsed_measure, + internal::ParseAggregateMeasure(agg_measure, ext_set, conversion_options, + /*is_hash=*/!keys.empty(), input_schema)); + aggregates.push_back(std::move(parsed_measure.aggregate)); + agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset)); } ARROW_ASSIGN_OR_RAISE( @@ -826,8 +824,9 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(aggregates), std::move(agg_src_fieldsets), std::move(keys), std::move(key_field_ids), {}, {}, ext_set, conversion_options)); + auto aggregate_schema = aggregate_declaration.output_schema; return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), - aggregate_declaration.output_schema); + std::move(aggregate_schema)); } case substrait::Rel::RelTypeCase::kExtensionLeaf: diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 1ef909ded89..8d560d9f41f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -21,6 +21,7 @@ #include +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/type_fwd.h" #include "arrow/engine/substrait/relation.h" #include "arrow/engine/substrait/type_fwd.h" @@ -30,12 +31,6 @@ #include "substrait/algebra.pb.h" // IWYU pragma: export namespace arrow { -namespace compute { - -struct Aggregate; -class AggregateNodeOptions; - -} // namespace compute namespace engine { /// \brief Convert a Substrait Rel object to an Acero declaration @@ -54,6 +49,11 @@ ARROW_ENGINE_EXPORT Result> ToProto( namespace internal { +struct ParsedMeasure { + compute::Aggregate aggregate; + std::vector fieldset; +}; + /// \brief Parse an aggregate relation's measure /// /// \param[in] agg_measure the measure @@ -61,14 +61,11 @@ namespace internal { /// \param[in] conversion_options options to control how the conversion is done /// \param[in] input_schema the schema to which field refs apply /// \param[in] is_hash whether the measure is a hash one (i.e., aggregation keys exist) -/// \param[out] aggregates points to vector to push the parsed measure into -/// \param[out] agg_src_fieldsets points to vector to push the parsed field set into -ARROW_ENGINE_EXPORT Status ParseAggregateMeasure( +ARROW_ENGINE_EXPORT +Result ParseAggregateMeasure( const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, const ConversionOptions& conversion_options, bool is_hash, - const std::shared_ptr input_schema, - std::vector* aggregates, - std::vector>* agg_src_fieldsets); + const std::shared_ptr input_schema); /// \brief Make an aggregate declaration info /// From 450cd962549f76a961713d990dc5f4e8a327f143 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 29 Mar 2023 16:38:55 -0400 Subject: [PATCH 5/5] ReferenceSegment -> FieldReference --- cpp/proto/substrait/extension_rels.proto | 4 ++-- .../engine/substrait/expression_internal.cc | 17 +++++++++++++++-- .../engine/substrait/expression_internal.h | 4 ++-- cpp/src/arrow/engine/substrait/options.cc | 8 ++++---- cpp/src/arrow/engine/substrait/serde_test.cc | 18 ++++++++++++------ 5 files changed, 35 insertions(+), 16 deletions(-) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index f6d6ed4717c..ec86824c65a 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -61,10 +61,10 @@ message NamedTapRel { message SegmentedAggregateRel { // Grouping keys of the aggregation - repeated substrait.Expression.ReferenceSegment grouping_keys = 1; + repeated substrait.Expression.FieldReference grouping_keys = 1; // Segment keys of the aggregation - repeated substrait.Expression.ReferenceSegment segment_keys = 2; + repeated substrait.Expression.FieldReference segment_keys = 2; // A list of one or more aggregate expressions along with an optional filter. repeated substrait.AggregateRel.Measure measures = 3; diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 9a5bca41128..722dec2a300 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -209,11 +209,24 @@ Result FromProto(const substrait::Expression::ReferenceSegm in_ref ? in_ref->DebugString() : "null"); } +Result FromProto(const substrait::Expression::FieldReference* fref, + const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, + std::optional in_expr) { + if (fref->reference_type_case() != + substrait::Expression::FieldReference::kDirectReference || + fref->root_type_case() != substrait::Expression::FieldReference::kRootReference) { + return Status::NotImplemented("substrait::FieldReference not direct root reference"); + } + auto& dref = fref->direct_reference(); + return FromProto(&dref, ext_set, conversion_options, std::move(in_expr)); +} + Result DirectReferenceFromProto( - const substrait::Expression::ReferenceSegment* refseg, const ExtensionSet& ext_set, + const substrait::Expression::FieldReference* fref, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(compute::Expression expr, - FromProto(refseg, ext_set, conversion_options, {})); + FromProto(fref, ext_set, conversion_options, {})); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { return *field_ref; diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index 9239ee5d235..cddc46066f1 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -36,11 +36,11 @@ namespace engine { class SubstraitCall; ARROW_ENGINE_EXPORT -Result DirectReferenceFromProto(const substrait::Expression::ReferenceSegment*, +Result DirectReferenceFromProto(const substrait::Expression::FieldReference*, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result FromProto(const substrait::Expression::ReferenceSegment*, +Result FromProto(const substrait::Expression::FieldReference*, const ExtensionSet&, const ConversionOptions&, std::optional); diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index f0440bf4561..da07f151633 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -190,9 +190,9 @@ class DefaultExtensionProvider : public BaseExtensionProvider { // store key fields to be used when output schema is created std::vector key_field_ids; std::vector keys; - for (auto& key_refseg : seg_agg_rel.grouping_keys()) { + for (auto& ref : seg_agg_rel.grouping_keys()) { ARROW_ASSIGN_OR_RAISE(auto field_ref, - DirectReferenceFromProto(&key_refseg, ext_set, conv_opts)); + DirectReferenceFromProto(&ref, ext_set, conv_opts)); ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); key_field_ids.emplace_back(std::move(match[0])); keys.emplace_back(std::move(field_ref)); @@ -201,9 +201,9 @@ class DefaultExtensionProvider : public BaseExtensionProvider { // store segment key fields to be used when output schema is created std::vector segment_key_field_ids; std::vector segment_keys; - for (auto& key_refseg : seg_agg_rel.segment_keys()) { + for (auto& ref : seg_agg_rel.segment_keys()) { ARROW_ASSIGN_OR_RAISE(auto field_ref, - DirectReferenceFromProto(&key_refseg, ext_set, conv_opts)); + DirectReferenceFromProto(&ref, ext_set, conv_opts)); ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); segment_key_field_ids.emplace_back(std::move(match[0])); segment_keys.emplace_back(std::move(field_ref)); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index fe3c5d29018..0f605cd7ce4 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5666,14 +5666,20 @@ TEST(Substrait, PlanWithSegmentedAggregateExtension) { "detail": { "@type": "/arrow.substrait_ext.SegmentedAggregateRel", "grouping_keys": [{ - "structField": { - "field": 1 - } + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} }], "segment_keys": [{ - "structField": { - "field": 0 - } + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} }], "measures": [{ "measure": {