diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index 78c11b7d7e2..ec86824c65a 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -58,3 +58,14 @@ message NamedTapRel { // If empty, field names will be automatically generated. repeated string columns = 3; } + +message SegmentedAggregateRel { + // Grouping keys of the aggregation + repeated substrait.Expression.FieldReference grouping_keys = 1; + + // Segment keys of the aggregation + repeated substrait.Expression.FieldReference segment_keys = 2; + + // A list of one or more aggregate expressions along with an optional filter. + repeated substrait.AggregateRel.Measure measures = 3; +} diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 310c0c64740..73eea80196f 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -28,8 +28,10 @@ #include "arrow/compute/expression.h" #include "arrow/datum.h" #include "arrow/io/util_internal.h" +#include "arrow/ipc/util.h" #include "arrow/result.h" #include "arrow/table.h" +#include "arrow/util/align_util.h" #include "arrow/util/async_generator.h" #include "arrow/util/async_util.h" #include "arrow/util/checked_cast.h" @@ -102,6 +104,13 @@ struct SourceNode : ExecNode, public TracedNode { batch_size = morsel_length; } ExecBatch batch = morsel.Slice(offset, batch_size); + for (auto& value : batch.values) { + if (value.is_array()) { + ARROW_ASSIGN_OR_RAISE( + value, util::EnsureAlignment(value.make_array(), ipc::kArrowAlignment, + default_memory_pool())); + } + } if (has_ordering) { batch.index = batch_index; } diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index e1223a51329..722dec2a300 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ 
b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -138,6 +138,105 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor* desc return value_desc->name(); } +Result FromProto(const substrait::Expression::ReferenceSegment* ref, + const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, + std::optional in_expr) { + auto in_ref = ref; + auto& current = in_expr; + while (ref != nullptr) { + switch (ref->reference_type_case()) { + case substrait::Expression::ReferenceSegment::kStructField: { + auto index = ref->struct_field().field(); + if (!current) { + // Root StructField (column selection) + current = compute::field_ref(FieldRef(index)); + } else if (auto current_ref = current->field_ref()) { + // Nested StructFields on the root (selection of struct-typed column + // combined with selecting struct fields) + current = compute::field_ref(FieldRef(*current_ref, index)); + } else if (current->call() && current->call()->function_name == "struct_field") { + // Nested StructFields on top of an arbitrary expression + auto* field_options = + checked_cast(current->call()->options.get()); + field_options->field_ref = FieldRef(std::move(field_options->field_ref), index); + } else { + // First StructField on top of an arbitrary expression + current = compute::call("struct_field", {std::move(*current)}, + arrow::compute::StructFieldOptions({index})); + } + + // Segment handled, continue with child segment (if any) + if (ref->struct_field().has_child()) { + ref = &ref->struct_field().child(); + } else { + ref = nullptr; + } + break; + } + case substrait::Expression::ReferenceSegment::kListElement: { + if (!current) { + // Root ListField (illegal) + return Status::Invalid( + "substrait::ListElement cannot take a Relation as an argument"); + } + + // ListField on top of an arbitrary expression + current = compute::call( + "list_element", + {std::move(*current), compute::literal(ref->list_element().offset())}); + + // Segment handled, 
continue with child segment (if any) + if (ref->list_element().has_child()) { + ref = &ref->list_element().child(); + } else { + ref = nullptr; + } + break; + } + default: + // Unimplemented construct, break out of loop + current.reset(); + ref = nullptr; + } + } + if (current) { + return *std::move(current); + } + + return Status::NotImplemented( + "conversion to arrow::compute::Expression from Substrait reference segment: ", + in_ref ? in_ref->DebugString() : "null"); +} + +Result FromProto(const substrait::Expression::FieldReference* fref, + const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, + std::optional in_expr) { + if (fref->reference_type_case() != + substrait::Expression::FieldReference::kDirectReference || + fref->root_type_case() != substrait::Expression::FieldReference::kRootReference) { + return Status::NotImplemented("substrait::FieldReference not direct root reference"); + } + auto& dref = fref->direct_reference(); + return FromProto(&dref, ext_set, conversion_options, std::move(in_expr)); +} + +Result DirectReferenceFromProto( + const substrait::Expression::FieldReference* fref, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(compute::Expression expr, + FromProto(fref, ext_set, conversion_options, {})); + const FieldRef* field_ref = expr.field_ref(); + if (field_ref) { + return *field_ref; + } else { + return Status::Invalid( + "A direct reference was expected but a more complex expression was given " + "instead"); + } +} + Result FromProto(const substrait::AggregateFunction& func, bool is_hash, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { @@ -193,67 +292,7 @@ Result FromProto(const substrait::Expression& expr, } const auto* ref = &expr.selection().direct_reference(); - while (ref != nullptr) { - switch (ref->reference_type_case()) { - case substrait::Expression::ReferenceSegment::kStructField: { - auto index = 
ref->struct_field().field(); - if (!out) { - // Root StructField (column selection) - out = compute::field_ref(FieldRef(index)); - } else if (auto out_ref = out->field_ref()) { - // Nested StructFields on the root (selection of struct-typed column - // combined with selecting struct fields) - out = compute::field_ref(FieldRef(*out_ref, index)); - } else if (out->call() && out->call()->function_name == "struct_field") { - // Nested StructFields on top of an arbitrary expression - auto* field_options = - checked_cast(out->call()->options.get()); - field_options->field_ref = - FieldRef(std::move(field_options->field_ref), index); - } else { - // First StructField on top of an arbitrary expression - out = compute::call("struct_field", {std::move(*out)}, - arrow::compute::StructFieldOptions({index})); - } - - // Segment handled, continue with child segment (if any) - if (ref->struct_field().has_child()) { - ref = &ref->struct_field().child(); - } else { - ref = nullptr; - } - break; - } - case substrait::Expression::ReferenceSegment::kListElement: { - if (!out) { - // Root ListField (illegal) - return Status::Invalid( - "substrait::ListElement cannot take a Relation as an argument"); - } - - // ListField on top of an arbitrary expression - out = compute::call( - "list_element", - {std::move(*out), compute::literal(ref->list_element().offset())}); - - // Segment handled, continue with child segment (if any) - if (ref->list_element().has_child()) { - ref = &ref->list_element().child(); - } else { - ref = nullptr; - } - break; - } - default: - // Unimplemented construct, break out of loop - out.reset(); - ref = nullptr; - } - } - if (out) { - return *std::move(out); - } - break; + return FromProto(ref, ext_set, conversion_options, std::move(out)); } case substrait::Expression::kIfThen: { diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index e947537dd1e..cddc46066f1 100644 --- 
a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include "arrow/compute/type_fwd.h" #include "arrow/datum.h" @@ -34,6 +35,15 @@ namespace engine { class SubstraitCall; +ARROW_ENGINE_EXPORT +Result DirectReferenceFromProto(const substrait::Expression::FieldReference*, + const ExtensionSet&, const ConversionOptions&); + +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::Expression::FieldReference*, + const ExtensionSet&, const ConversionOptions&, + std::optional); + ARROW_ENGINE_EXPORT Result FromProto(const substrait::Expression&, const ExtensionSet&, const ConversionOptions&); diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 4a4f3a87150..da07f151633 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -75,6 +75,11 @@ class DefaultExtensionProvider : public BaseExtensionProvider { rel.UnpackTo(&named_tap_rel); return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set); } + if (rel.Is()) { + substrait_ext::SegmentedAggregateRel seg_agg_rel; + rel.UnpackTo(&seg_agg_rel); + return MakeSegmentedAggregateRel(conv_opts, inputs, seg_agg_rel, ext_set); + } return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ", rel.DebugString()); } @@ -165,6 +170,74 @@ class DefaultExtensionProvider : public BaseExtensionProvider { named_tap_rel.name(), renamed_schema)); return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt}; } + + Result MakeSegmentedAggregateRel( + const ConversionOptions& conv_opts, const std::vector& inputs, + const substrait_ext::SegmentedAggregateRel& seg_agg_rel, + const ExtensionSet& ext_set) { + if (inputs.size() != 1) { + return Status::Invalid( + "substrait_ext::SegmentedAggregateRel requires a single input but got: ", + inputs.size()); + } + if (seg_agg_rel.segment_keys_size() == 0) 
{ + return Status::Invalid( + "substrait_ext::SegmentedAggregateRel requires at least one segment key"); + } + + auto input_schema = inputs[0].output_schema; + + // store key fields to be used when output schema is created + std::vector key_field_ids; + std::vector keys; + for (auto& ref : seg_agg_rel.grouping_keys()) { + ARROW_ASSIGN_OR_RAISE(auto field_ref, + DirectReferenceFromProto(&ref, ext_set, conv_opts)); + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + key_field_ids.emplace_back(std::move(match[0])); + keys.emplace_back(std::move(field_ref)); + } + + // store segment key fields to be used when output schema is created + std::vector segment_key_field_ids; + std::vector segment_keys; + for (auto& ref : seg_agg_rel.segment_keys()) { + ARROW_ASSIGN_OR_RAISE(auto field_ref, + DirectReferenceFromProto(&ref, ext_set, conv_opts)); + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + segment_key_field_ids.emplace_back(std::move(match[0])); + segment_keys.emplace_back(std::move(field_ref)); + } + + std::vector aggregates; + aggregates.reserve(seg_agg_rel.measures_size()); + std::vector> agg_src_fieldsets; + agg_src_fieldsets.reserve(seg_agg_rel.measures_size()); + for (auto agg_measure : seg_agg_rel.measures()) { + ARROW_ASSIGN_OR_RAISE( + auto parsed_measure, + internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts, + /*is_hash=*/!keys.empty(), input_schema)); + aggregates.push_back(std::move(parsed_measure.aggregate)); + agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset)); + } + + ARROW_ASSIGN_OR_RAISE(auto decl_info, + internal::MakeAggregateDeclaration( + std::move(inputs[0].declaration), std::move(input_schema), + seg_agg_rel.measures_size(), std::move(aggregates), + std::move(agg_src_fieldsets), std::move(keys), + std::move(key_field_ids), std::move(segment_keys), + std::move(segment_key_field_ids), ext_set, conv_opts)); + + const auto& output_schema = decl_info.output_schema; + size_t out_size = 
output_schema->num_fields(); + std::vector field_output_indices(out_size); + for (int i = 0; i < static_cast(out_size); i++) { + field_output_indices[i] = i; + } + return RelationInfo{decl_info, std::move(field_output_indices)}; + } }; namespace { diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index eefc37607ad..7f9c4c289fb 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -293,6 +293,81 @@ Status DiscoverFilesFromDir(const std::shared_ptr& local_fs return Status::OK(); } +namespace internal { + +Result ParseAggregateMeasure( + const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, bool is_hash, + const std::shared_ptr input_schema) { + if (agg_measure.has_measure()) { + if (agg_measure.has_filter()) { + return Status::NotImplemented("Aggregate filters are not supported."); + } + const auto& agg_func = agg_measure.measure(); + ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call, + FromProto(agg_func, is_hash, ext_set, conversion_options)); + ExtensionIdRegistry::SubstraitAggregateToArrow converter; + if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') { + ARROW_ASSIGN_OR_RAISE(converter, + ext_set.registry()->GetSubstraitAggregateToArrowFallback( + aggregate_call.id().name)); + } else { + ARROW_ASSIGN_OR_RAISE(converter, ext_set.registry()->GetSubstraitAggregateToArrow( + aggregate_call.id())); + } + ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); + + // find aggregate field ids from schema + const auto& target = arrow_agg.target; + std::vector fieldset; + fieldset.reserve(target.size()); + for (const auto& field_ref : target) { + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + fieldset.push_back(match[0]); + } + + return ParsedMeasure{std::move(arrow_agg), 
std::move(fieldset)}; + } else { + return Status::Invalid("substrait::AggregateFunction not provided"); + } +} + +ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( + compute::Declaration input_decl, std::shared_ptr input_schema, + const int measure_size, std::vector aggregates, + std::vector> agg_src_fieldsets, std::vector keys, + std::vector key_field_ids, std::vector segment_keys, + std::vector segment_key_field_ids, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { + FieldVector output_fields; + output_fields.reserve(key_field_ids.size() + segment_key_field_ids.size() + + measure_size); + // extract aggregate fields to output schema + for (const auto& agg_src_fieldset : agg_src_fieldsets) { + for (int field : agg_src_fieldset) { + output_fields.emplace_back(input_schema->field(field)); + } + } + // extract key fields to output schema + for (int key_field_id : key_field_ids) { + output_fields.emplace_back(input_schema->field(key_field_id)); + } + // extract segment key fields to output schema + for (int segment_key_field_id : segment_key_field_ids) { + output_fields.emplace_back(input_schema->field(segment_key_field_id)); + } + + std::shared_ptr aggregate_schema = schema(std::move(output_fields)); + + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(input_decl), + {"aggregate", compute::AggregateNodeOptions{aggregates, keys, segment_keys}}}), + aggregate_schema}; +} + +} // namespace internal + Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { static bool dataset_init = false; @@ -730,62 +805,26 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::vector aggregates; aggregates.reserve(measure_size); // store aggregate fields to be used when output schema is created - std::vector> agg_src_fieldsets(measure_size); + std::vector> agg_src_fieldsets; + agg_src_fieldsets.reserve(measure_size); for (int measure_id = 0; measure_id 
< measure_size; measure_id++) { const auto& agg_measure = aggregate.measures(measure_id); - if (agg_measure.has_measure()) { - if (agg_measure.has_filter()) { - return Status::NotImplemented("Aggregate filters are not supported."); - } - const auto& agg_func = agg_measure.measure(); - ARROW_ASSIGN_OR_RAISE(SubstraitCall aggregate_call, - FromProto(agg_func, /*is_hash=*/!keys.empty(), ext_set, - conversion_options)); - ExtensionIdRegistry::SubstraitAggregateToArrow converter; - if (aggregate_call.id().uri.empty() || aggregate_call.id().uri[0] == '/') { - ARROW_ASSIGN_OR_RAISE( - converter, ext_set.registry()->GetSubstraitAggregateToArrowFallback( - aggregate_call.id().name)); - } else { - ARROW_ASSIGN_OR_RAISE( - converter, - ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id())); - } - ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); - - // find aggregate field ids from schema - const auto& target = arrow_agg.target; - for (const auto& field_ref : target) { - ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); - agg_src_fieldsets[measure_id].push_back(match[0]); - } - - aggregates.push_back(std::move(arrow_agg)); - } else { - return Status::Invalid("substrait::AggregateFunction not provided"); - } - } - FieldVector output_fields; - output_fields.reserve(key_field_ids.size() + measure_size); - // extract aggregate fields to output schema - for (const auto& agg_src_fieldset : agg_src_fieldsets) { - for (int field : agg_src_fieldset) { - output_fields.emplace_back(input_schema->field(field)); - } - } - // extract key fields to output schema - for (int key_field_id : key_field_ids) { - output_fields.emplace_back(input_schema->field(key_field_id)); + ARROW_ASSIGN_OR_RAISE( + auto parsed_measure, + internal::ParseAggregateMeasure(agg_measure, ext_set, conversion_options, + /*is_hash=*/!keys.empty(), input_schema)); + aggregates.push_back(std::move(parsed_measure.aggregate)); + 
agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset)); } - std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - - DeclarationInfo aggregate_declaration{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - aggregate_schema}; + ARROW_ASSIGN_OR_RAISE( + auto aggregate_declaration, + internal::MakeAggregateDeclaration( + std::move(input.declaration), std::move(input_schema), measure_size, + std::move(aggregates), std::move(agg_src_fieldsets), std::move(keys), + std::move(key_field_ids), {}, {}, ext_set, conversion_options)); + auto aggregate_schema = aggregate_declaration.output_schema; return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), std::move(aggregate_schema)); } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 17153f5365f..8d560d9f41f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -21,6 +21,7 @@ #include +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/type_fwd.h" #include "arrow/engine/substrait/relation.h" #include "arrow/engine/substrait/type_fwd.h" @@ -46,5 +47,48 @@ Result FromProto(const substrait::Rel&, const ExtensionSet&, ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); +namespace internal { + +struct ParsedMeasure { + compute::Aggregate aggregate; + std::vector fieldset; +}; + +/// \brief Parse an aggregate relation's measure +/// +/// \param[in] agg_measure the measure +/// \param[in] ext_set an extension mapping to use in parsing +/// \param[in] conversion_options options to control how the conversion is done +/// \param[in] input_schema the schema to which field refs apply +/// \param[in] is_hash whether the measure is a hash one (i.e., aggregation keys exist) +ARROW_ENGINE_EXPORT +Result 
ParseAggregateMeasure( + const substrait::AggregateRel::Measure& agg_measure, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options, bool is_hash, + const std::shared_ptr input_schema); + +/// \brief Make an aggregate declaration info +/// +/// \param[in] input_decl the input declaration to use +/// \param[in] input_schema the schema to which field refs apply +/// \param[in] measure_size the number of measures to use +/// \param[in] aggregates the aggregates to use +/// \param[in] agg_src_fieldsets the field-sets per aggregate to use +/// \param[in] keys the field-refs for grouping keys to use +/// \param[in] key_field_ids the field-ids for grouping keys to use +/// \param[in] segment_keys the field-refs for segment keys to use +/// \param[in] segment_key_field_ids the field-ids for segment keys to use +/// \param[in] ext_set an extension mapping to use +/// \param[in] conversion_options options to control how the conversion is done +ARROW_ENGINE_EXPORT Result MakeAggregateDeclaration( + compute::Declaration input_decl, std::shared_ptr input_schema, + const int measure_size, std::vector aggregates, + std::vector> agg_src_fieldsets, std::vector keys, + std::vector key_field_ids, std::vector segment_keys, + std::vector segment_key_field_ids, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options); + +} // namespace internal + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 32b3b5c5c17..0f605cd7ce4 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5607,5 +5607,132 @@ TEST(Substrait, PlanWithNamedTapExtension) { CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); } +TEST(Substrait, PlanWithSegmentedAggregateExtension) { + // This demos an extension relation + std::string substrait_json = R"({ + "extensionUris": [{ + 
"extension_uri_anchor": 0, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "sum" + } + }], + "relations": [{ + "root": { + "input": { + "extension_single": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T"] + } + } + }, + "detail": { + "@type": "/arrow.substrait_ext.SegmentedAggregateRel", + "grouping_keys": [{ + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + }], + "segment_keys": [{ + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + }], + "measures": [{ + "measure": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + } + } + } + }], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + } + } + }] + } + } + }, + "names": ["v", "k", "t"] + } + }], + "expectedTypeUrls": [] + })"; + + std::shared_ptr input_schema = + schema({field("time", int32()), field("key", int32()), field("value", float64())}); + NamedTableProvider table_provider = AlwaysProvideSameTable( + TableFromJSON(input_schema, {"[[1, 1, 1], [1, 2, 2], [1, 1, 3]," + " [2, 2, 4], [2, 1, 5], [2, 2, 6]]"})); + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + 
conversion_options.named_tap_provider = + [](const std::string& tap_kind, std::vector inputs, + const std::string& tap_name, + std::shared_ptr tap_schema) -> Result { + return compute::Declaration{tap_kind, std::move(inputs), compute::ExecNodeOptions{}}; + }; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + std::shared_ptr output_schema = + schema({field("v", float64()), field("k", int32()), field("t", int32())}); + auto expected_table = + TableFromJSON(output_schema, {"[[4, 1, 1], [2, 2, 1], [10, 2, 2], [5, 1, 2]]"}); + CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); +} + } // namespace engine } // namespace arrow