diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c5d212c8c2f..4213895b616 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -42,11 +42,45 @@ using internal::make_unique; namespace engine { template -Status CheckRelCommon(const RelMessage& rel) { +Result> GetEmitInfo( + const RelMessage& rel, const std::shared_ptr& schema) { + const auto& emit = rel.common().emit(); + int emit_size = emit.output_mapping_size(); + std::vector proj_field_refs(emit_size); + for (int i = 0; i < emit_size; i++) { + int32_t map_id = emit.output_mapping(i); + proj_field_refs[i] = compute::field_ref(FieldRef(map_id)); + } + return std::move(proj_field_refs); +} + +template +Result ProcessEmit(const RelMessage& rel, + const DeclarationInfo& no_emit_declr, + const std::shared_ptr& schema) { if (rel.has_common()) { - if (rel.common().has_emit()) { - return Status::NotImplemented("substrait::RelCommon::Emit"); + switch (rel.common().emit_kind_case()) { + case substrait::RelCommon::EmitKindCase::kDirect: + return no_emit_declr; + case substrait::RelCommon::EmitKindCase::kEmit: { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(rel, schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {no_emit_declr.declaration, + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), + std::move(schema)}; + } + default: + return Status::Invalid("Invalid emit case"); } + } else { + return no_emit_declr; + } +} + +template +Status CheckRelCommon(const RelMessage& rel) { + if (rel.has_common()) { if (rel.common().has_hint()) { return Status::NotImplemented("substrait::RelCommon::Hint"); } @@ -75,7 +109,6 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set, conversion_options)); - auto num_columns = 
static_cast(base_schema->fields().size()); auto scan_options = std::make_shared(); scan_options->use_threads = true; @@ -104,7 +137,9 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& named_table.names().end()); ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, named_table_provider(table_names)); - return DeclarationInfo{std::move(source_decl), num_columns}; + return ProcessEmit(std::move(read), + DeclarationInfo{std::move(source_decl), base_schema}, + std::move(base_schema)); } if (!read.has_local_files()) { @@ -216,12 +251,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(filesystem), std::move(files), std::move(format), {})); - ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema))); + ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); + + DeclarationInfo scan_declaration = { + compute::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}}, + base_schema}; - return DeclarationInfo{ - compute::Declaration{ - "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, - num_columns}; + return ProcessEmit(std::move(read), std::move(scan_declaration), + std::move(base_schema)); } case substrait::Rel::RelTypeCase::kFilter: { @@ -239,19 +276,20 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); - - return DeclarationInfo{ + DeclarationInfo filter_declaration{ compute::Declaration::Sequence({ std::move(input.declaration), {"filter", compute::FilterNodeOptions{std::move(condition)}}, }), - input.num_columns}; + input.output_schema}; + + return ProcessEmit(std::move(filter), std::move(filter_declaration), + input.output_schema); } case substrait::Rel::RelTypeCase::kProject: { const auto& project = rel.project(); RETURN_NOT_OK(CheckRelCommon(project)); - if (!project.has_input()) { return Status::Invalid("substrait::ProjectRel with no input 
relation"); } @@ -261,23 +299,48 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces // them. Therefore, we need to prefix all the current columns for compatibility. std::vector expressions; - expressions.reserve(input.num_columns + project.expressions().size()); - for (int i = 0; i < input.num_columns; i++) { + int num_columns = input.output_schema->num_fields(); + expressions.reserve(num_columns + project.expressions().size()); + for (int i = 0; i < num_columns; i++) { expressions.emplace_back(compute::field_ref(FieldRef(i))); } + + int i = 0; + auto project_schema = input.output_schema; for (const auto& expr : project.expressions()) { - expressions.emplace_back(); - ARROW_ASSIGN_OR_RAISE(expressions.back(), + std::shared_ptr project_field; + ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr, FromProto(expr, ext_set, conversion_options)); + auto bound_expr = des_expr.Bind(*input.output_schema); + if (auto* expr_call = bound_expr->call()) { + project_field = field(expr_call->function_name, + expr_call->kernel->signature->out_type().type()); + } else if (auto* field_ref = des_expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(FieldPath field_path, + field_ref->FindOne(*input.output_schema)); + ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema)); + } else if (auto* literal = des_expr.literal()) { + project_field = + field("field_" + std::to_string(num_columns + i), literal->type()); + } + ARROW_ASSIGN_OR_RAISE( + project_schema, + project_schema->AddField( + num_columns + static_cast(project.expressions().size()) - 1, + std::move(project_field))); + i++; + expressions.emplace_back(des_expr); } - auto num_columns = static_cast(expressions.size()); - return DeclarationInfo{ + DeclarationInfo project_declaration{ compute::Declaration::Sequence({ std::move(input.declaration), {"project", compute::ProjectNodeOptions{std::move(expressions)}}, }), - 
num_columns}; + project_schema}; + + return ProcessEmit(std::move(project), std::move(project_declaration), + std::move(project_schema)); } case substrait::Rel::RelTypeCase::kJoin: { @@ -355,15 +418,26 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& if (!left_keys || !right_keys) { return Status::Invalid("Left keys for join cannot be null"); } + + // Create output schema from left, right relations and join keys + FieldVector combined_fields = left.output_schema->fields(); + const FieldVector& right_fields = right.output_schema->fields(); + combined_fields.insert(combined_fields.end(), right_fields.begin(), + right_fields.end()); + std::shared_ptr join_schema = schema(std::move(combined_fields)); + compute::HashJoinNodeOptions join_options{{std::move(*left_keys)}, {std::move(*right_keys)}}; join_options.join_type = join_type; join_options.key_cmp = {join_key_cmp}; compute::Declaration join_dec{"hashjoin", std::move(join_options)}; - auto num_columns = left.num_columns + right.num_columns; join_dec.inputs.emplace_back(std::move(left.declaration)); join_dec.inputs.emplace_back(std::move(right.declaration)); - return DeclarationInfo{std::move(join_dec), num_columns}; + + DeclarationInfo join_declaration{std::move(join_dec), join_schema}; + + return ProcessEmit(std::move(join), std::move(join_declaration), + std::move(join_schema)); } case substrait::Rel::RelTypeCase::kAggregate: { const auto& aggregate = rel.aggregate(); @@ -381,16 +455,25 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& "Grouping sets not supported. 
AggregateRel::groupings may not have more " "than one item"); } + + // prepare output schema from aggregates + auto input_schema = input.output_schema; + // store key fields to be used when output schema is created + std::vector key_field_ids; std::vector keys; if (aggregate.groupings_size() > 0) { const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); - keys.reserve(group.grouping_expressions_size()); - for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) { + int grouping_expr_size = group.grouping_expressions_size(); + keys.reserve(grouping_expr_size); + key_field_ids.reserve(grouping_expr_size); + for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) { ARROW_ASSIGN_OR_RAISE( compute::Expression expr, FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options)); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { + ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); + key_field_ids.emplace_back(std::move(match[0])); keys.emplace_back(std::move(*field_ref)); } else { return Status::Invalid( @@ -402,6 +485,8 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& int measure_size = aggregate.measures_size(); std::vector aggregates; aggregates.reserve(measure_size); + // store aggregate fields to be used when output schema is created + std::vector agg_src_field_ids(measure_size); for (int measure_id = 0; measure_id < measure_size; measure_id++) { const auto& agg_measure = aggregate.measures(measure_id); if (agg_measure.has_measure()) { @@ -416,17 +501,38 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ExtensionIdRegistry::SubstraitAggregateToArrow converter, ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id())); ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); + + // find aggregate field ids from schema + const auto field_ref = arrow_agg.target; + ARROW_ASSIGN_OR_RAISE(auto match, 
field_ref.FindOne(*input_schema)); + agg_src_field_ids[measure_id] = match[0]; + aggregates.push_back(std::move(arrow_agg)); } else { return Status::Invalid("substrait::AggregateFunction not provided"); } } + FieldVector output_fields; + output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size()); + // extract aggregate fields to output schema + for (int id = 0; id < static_cast(agg_src_field_ids.size()); id++) { + output_fields.emplace_back(input_schema->field(agg_src_field_ids[id])); + } + // extract key fields to output schema + for (int id = 0; id < static_cast(key_field_ids.size()); id++) { + output_fields.emplace_back(input_schema->field(key_field_ids[id])); + } + + std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - return DeclarationInfo{ + DeclarationInfo aggregate_declaration{ compute::Declaration::Sequence( {std::move(input.declaration), {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - static_cast(aggregates.size())}; + aggregate_schema}; + + return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), + std::move(aggregate_schema)); } default: diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 778d1e5bc01..514f3f97fc0 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -36,8 +36,7 @@ struct DeclarationInfo { /// The compute declaration produced thus far. compute::Declaration declaration; - /// The number of columns returned by the declaration. 
- int num_columns; + std::shared_ptr output_schema; }; /// \brief Convert a Substrait Rel object to an Acero declaration diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 9b6c3f715f7..251c2bfe352 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -24,12 +24,10 @@ #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" - #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" #include "arrow/engine/substrait/serde.h" - #include "arrow/engine/substrait/util.h" #include "arrow/filesystem/localfs.h" @@ -58,31 +56,35 @@ using internal::checked_cast; using internal::hash_combine; namespace engine { -Status WriteIpcData(const std::string& path, - const std::shared_ptr file_system, - const std::shared_ptr input) { +void WriteIpcData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr
input) { EXPECT_OK_AND_ASSIGN(auto mmap, file_system->OpenOutputStream(path)); - ARROW_ASSIGN_OR_RAISE( + ASSERT_OK_AND_ASSIGN( auto file_writer, MakeFileWriter(mmap, input->schema(), ipc::IpcWriteOptions::Defaults())); TableBatchReader reader(input); std::shared_ptr batch; while (true) { - RETURN_NOT_OK(reader.ReadNext(&batch)); + ASSERT_OK(reader.ReadNext(&batch)); if (batch == nullptr) { break; } - RETURN_NOT_OK(file_writer->WriteRecordBatch(*batch)); + ASSERT_OK(file_writer->WriteRecordBatch(*batch)); } - RETURN_NOT_OK(file_writer->Close()); - return Status::OK(); + ASSERT_OK(file_writer->Close()); } Result> GetTableFromPlan( - compute::Declaration& declarations, - arrow::AsyncGenerator>& sink_gen, - compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + compute::Declaration& other_declrs, compute::ExecContext& exec_context, + const std::shared_ptr& output_schema) { ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); + + arrow::AsyncGenerator> sink_gen; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); + auto declarations = compute::Declaration::Sequence({other_declrs, sink_declaration}); + ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); RETURN_NOT_OK(decl->Validate()); @@ -171,6 +173,29 @@ inline compute::Expression UseBoringRefs(const compute::Expression& expr) { return compute::Expression{std::move(modified_call)}; } +void CheckRoundTripResult(const std::shared_ptr output_schema, + const std::shared_ptr
expected_table, + compute::ExecContext& exec_context, + std::shared_ptr& buf, + const std::vector& include_columns = {}, + const ConversionOptions& conversion_options = {}) { + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans( + *buf, [] { return kNullConsumer; }, + ext_id_reg, &ext_set, conversion_options)); + auto other_declrs = sink_decls[0].inputs[0].get(); + + ASSERT_OK_AND_ASSIGN(auto output_table, + GetTableFromPlan(*other_declrs, exec_context, output_schema)); + if (!include_columns.empty()) { + ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); + } + ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); + EXPECT_TRUE(expected_table->Equals(*output_table)); +} + TEST(Substrait, SupportedTypes) { auto ExpectEq = [](util::string_view json, std::shared_ptr expected_type) { ARROW_SCOPED_TRACE(json); @@ -1903,7 +1928,7 @@ TEST(Substrait, BasicPlanRoundTripping) { ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + WriteIpcData(file_path_str, filesystem, table); std::vector files; const std::vector f_paths = {file_path_str}; @@ -2015,7 +2040,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + WriteIpcData(file_path_str, filesystem, table); std::vector files; const std::vector f_paths = {file_path_str}; @@ -2037,16 +2062,13 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { auto comp_right_value = compute::field_ref(filter_col_right); auto filter = compute::equal(comp_left_value, comp_right_value); - arrow::AsyncGenerator> sink_gen; - 
auto declarations = compute::Declaration::Sequence( {compute::Declaration( {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), - compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), - compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"})}); - ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(declarations, sink_gen, - exec_context, dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto expected_table, + GetTableFromPlan(declarations, exec_context, dummy_schema)); std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); @@ -2091,17 +2113,971 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { checked_cast(roundtrip_frg_vec[idx++].get()); EXPECT_TRUE(l_frag->Equals(*r_frag)); } - arrow::AsyncGenerator> rnd_trp_sink_gen; - auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; - auto rnd_trp_sink_declaration = - compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); - auto rnd_trp_declarations = - compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, - GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, - exec_context, dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, GetTableFromPlan(*roundtripped_filter, + exec_context, dummy_schema)); EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } +TEST(Substrait, ProjectRel) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 5, 20], + [4, 1, 30], + [2, 1, 40], + [5, 5, 50], + [2, 2, 60] + ])"}); + + std::string 
substrait_json = R"({ + "relations": [{ + "rel": { + "project": { + "expressions": [{ + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "output_type": { + "bool": {} + } + } + }, + ], + "input" : { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable": { + "names": [] + } + } + } + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("equal", boolean())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 1, 10, true], + [3, 5, 20, false], + [4, 1, 30, false], + [2, 1, 40, false], + [5, 5, 50, true], + [2, 2, 60, true] + ])"}); + + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); +} + +TEST(Substrait, ProjectRelOnFunctionWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext 
exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 5, 20], + [4, 1, 30], + [2, 1, 40], + [5, 5, 50], + [2, 2, 60] + ])"}); + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "project": { + "common": { + "emit": { + "outputMapping": [0, 2, 3] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "output_type": { + "bool": {} + } + } + }, + ], + "input" : { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable": { + "names": [] + } + } + } + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = + schema({field("A", int32()), field("C", int32()), field("equal", boolean())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 10, true], + [3, 20, false], + [4, 30, false], + [2, 40, false], + [5, 50, true], + [2, 60, true] + ])"}); + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + 
conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); +} + +TEST(Substrait, ReadRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])"}); + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "read": { + "common": { + "emit": { + "outputMapping": [1, 2] + } + }, + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable": { + "names" : [] + } + } + } + }], + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("B", int32()), field("C", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 10], + [4, 20] + ])"}); + + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); +} + +TEST(Substrait, FilterRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("D", int32())}); + + // creating a dummy dataset using a 
dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [10, 1, 80, 7], + [20, 2, 70, 6], + [30, 3, 30, 5], + [40, 4, 20, 4], + [40, 5, 40, 3], + [20, 6, 20, 2], + [30, 7, 30, 1] + ])"}); + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "filter": { + "common": { + "emit": { + "outputMapping": [1, 3] + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }], + "output_type": { + "bool": {} + } + } + }, + "input" : { + "read": { + "base_schema": { + "names": ["A", "B", "C", "D"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + },{ + "i32": {} + }] + } + }, + "namedTable": { + "names" : [] + } + } + } + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("B", int32()), field("D", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [3, 5], + [5, 3], + [6, 2], + [7, 1] + ])"}); + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); +} + +TEST(Substrait, 
JoinRelEndToEnd) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto left_schema = schema({field("A", int32()), field("B", int32())}); + + auto right_schema = schema({field("X", int32()), field("Y", int32())}); + + // creating a dummy dataset using a dummy table + auto left_table = TableFromJSON(left_schema, {R"([ + [10, 1], + [20, 2], + [30, 3] + ])"}); + + auto right_table = TableFromJSON(right_schema, {R"([ + [10, 11], + [80, 21], + [31, 31] + ])"}); + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "join": { + "left": { + "read": { + "base_schema": { + "names": ["A", "B"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable": { + "names" : ["left"] + } + } + }, + "right": { + "read": { + "base_schema": { + "names": ["X", "Y"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable": { + "names" : ["right"] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "output_type": { + "bool": {} + } + } + }, + "type": "JOIN_TYPE_INNER" + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + // include these columns for comparison + auto output_schema = schema({ + field("A", int32()), + field("B", int32()), + field("X", int32()), + field("Y", int32()), + }); + + auto 
expected_table = TableFromJSON(std::move(output_schema), {R"([ + [10, 1, 10, 11] + ])"}); + + NamedTableProvider table_provider = + [left_table, right_table](const std::vector& names) { + std::shared_ptr
output_table; + for (const auto& name : names) { + if (name == "left") { + output_table = left_table; + } + if (name == "right") { + output_table = right_table; + } + } + std::shared_ptr options = + std::make_shared(std::move(output_table)); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); +} + +TEST(Substrait, JoinRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto left_schema = schema({field("A", int32()), field("B", int32())}); + + auto right_schema = schema({field("X", int32()), field("Y", int32())}); + + // creating a dummy dataset using a dummy table + auto left_table = TableFromJSON(left_schema, {R"([ + [10, 1], + [20, 2], + [30, 3] + ])"}); + + auto right_table = TableFromJSON(right_schema, {R"([ + [10, 11], + [80, 21], + [31, 31] + ])"}); + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "join": { + "common": { + "emit": { + "outputMapping": [0, 1, 3] + } + }, + "left": { + "read": { + "base_schema": { + "names": ["A", "B"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable" : { + "names" : ["left"] + } + } + }, + "right": { + "read": { + "base_schema": { + "names": ["X", "Y"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable" : { + "names" : ["right"] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + 
}
+                }
+              }
+            }],
+            "output_type": {
+              "bool": {}
+            }
+          }
+        },
+        "type": "JOIN_TYPE_INNER"
+      }
+    }
+  }],
+  "extension_uris": [
+    {
+      "extension_uri_anchor": 0,
+      "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
+                               R"("
+    }
+  ],
+  "extensions": [
+    {"extension_function": {
+      "extension_uri_reference": 0,
+      "function_anchor": 0,
+      "name": "equal"
+    }}
+  ]
+  })";
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+  auto output_schema = schema({
+      field("A", int32()),
+      field("B", int32()),
+      field("Y", int32()),
+  });
+
+  // Pass the schema by copy: it is handed to CheckRoundTripResult further down,
+  // so it must not be moved from at this point.
+  auto expected_table = TableFromJSON(output_schema, {R"([
+    [10, 1, 11]
+  ])"});
+
+  // Maps the requested table names onto the in-memory tables. The last name
+  // that matches "left" or "right" wins; if none matches, the table source is
+  // built from a null table.
+  NamedTableProvider table_provider =
+      [left_table, right_table](const std::vector<std::string>& names) {
+        std::shared_ptr<Table> output_table;
+        for (const auto& name : names) {
+          if (name == "left") {
+            output_table = left_table;
+          }
+          if (name == "right") {
+            output_table = right_table;
+          }
+        }
+        std::shared_ptr<compute::ExecNodeOptions> options =
+            std::make_shared<compute::TableSourceNodeOptions>(std::move(output_table));
+        return compute::Declaration("table_source", {}, options, "mock_source");
+      };
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider = std::move(table_provider);
+
+  CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context,
+                       buf, {}, conversion_options);
+}
+
+// Round-trips a Substrait plan with an AggregateRel over a named-table read.
+// The plan groups on field 0 ("A") and applies the "sum" extension function to
+// field 2 ("C"). The expected output lists the aggregate column first and the
+// grouping key second.
+TEST(Substrait, AggregateRel) {
+#ifdef _WIN32
+  // NOTE(review): this plan reads from a named table rather than a file URI,
+  // so the skip below may not be needed -- confirm before removing.
+  GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+  compute::ExecContext exec_context;
+  auto dummy_schema =
+      schema({field("A", int32()), field("B", int32()), field("C", int32())});
+
+  // creating a dummy dataset using a dummy table
+  auto input_table = TableFromJSON(dummy_schema, {R"([
+    [10, 1, 80],
+    [20, 2, 70],
+    [30, 3, 30],
+    [40, 4, 20],
+    [40, 5, 40],
+    [20, 6, 20],
+    [30, 7, 30]
+  ])"});
+
+  // The plan is kept as strict JSON (no trailing commas) so that parsing does
+  // not depend on parser leniency.
+  std::string substrait_json = R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "namedTable" : {
+                "names": []
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "arguments": [{
+                "value": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 2
+                      }
+                    }
+                  }
+                }
+              }],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "invocation": "AGGREGATION_INVOCATION_ALL",
+              "outputType": {
+                "i64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "sum"
+      }
+    }]
+  })";
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+  auto output_schema = schema({field("aggregates", int64()), field("keys", int32())});
+  // One row per distinct value of "A": [sum of "C", "A"].
+  auto expected_table = TableFromJSON(output_schema, {R"([
+    [80, 10],
+    [90, 20],
+    [60, 30],
+    [60, 40]
+  ])"});
+
+  // The requested names are ignored; every named-table read is served from
+  // input_table.
+  NamedTableProvider table_provider = [input_table](const std::vector<std::string>&) {
+    std::shared_ptr<compute::ExecNodeOptions> options =
+        std::make_shared<compute::TableSourceNodeOptions>(input_table);
+    return compute::Declaration("table_source", {}, options, "mock_source");
+  };
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider = std::move(table_provider);
+
+  CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context,
+                       buf, {}, conversion_options);
+}
+
+// Same aggregation as AggregateRel, but the relation carries a RelCommon emit
+// with outputMapping [0]. The expected output therefore contains only the
+// first output column ("aggregates").
+TEST(Substrait, AggregateRelEmit) {
+#ifdef _WIN32
+  // NOTE(review): this plan reads from a named table rather than a file URI,
+  // so the skip below may not be needed -- confirm before removing.
+  GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+  compute::ExecContext exec_context;
+  auto dummy_schema =
+      schema({field("A", int32()), field("B", int32()), field("C", int32())});
+
+  // creating a dummy dataset using a dummy table
+  auto input_table = TableFromJSON(dummy_schema, {R"([
+    [10, 1, 80],
+    [20, 2, 70],
+    [30, 3, 30],
+    [40, 4, 20],
+    [40, 5, 40],
+    [20, 6, 20],
+    [30, 7, 30]
+  ])"});
+
+  // TODO: fixme https://issues.apache.org/jira/browse/ARROW-17484
+  // The plan is kept as strict JSON (no trailing commas) so that parsing does
+  // not depend on parser leniency.
+  std::string substrait_json = R"({
+    "relations": [{
+      "rel": {
+        "aggregate": {
+          "common": {
+            "emit": {
+              "outputMapping": [0]
+            }
+          },
+          "input": {
+            "read": {
+              "base_schema": {
+                "names": ["A", "B", "C"],
+                "struct": {
+                  "types": [{
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }, {
+                    "i32": {}
+                  }]
+                }
+              },
+              "namedTable" : {
+                "names" : []
+              }
+            }
+          },
+          "groupings": [{
+            "groupingExpressions": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                }
+              }
+            }]
+          }],
+          "measures": [{
+            "measure": {
+              "functionReference": 0,
+              "arguments": [{
+                "value": {
+                  "selection": {
+                    "directReference": {
+                      "structField": {
+                        "field": 2
+                      }
+                    }
+                  }
+                }
+              }],
+              "sorts": [],
+              "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+              "invocation": "AGGREGATION_INVOCATION_ALL",
+              "outputType": {
+                "i64": {}
+              }
+            }
+          }]
+        }
+      }
+    }],
+    "extensionUris": [{
+      "extension_uri_anchor": 0,
+      "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
+    }],
+    "extensions": [{
+      "extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "sum"
+      }
+    }]
+  })";
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+  auto output_schema = schema({field("aggregates", int64())});
+  // One row per distinct value of "A", holding only the sum of "C".
+  auto expected_table = TableFromJSON(output_schema, {R"([
+    [80],
+    [90],
+    [60],
+    [60]
+  ])"});
+
+  // The requested names are ignored; every named-table read is served from
+  // input_table.
+  NamedTableProvider table_provider = [input_table](const std::vector<std::string>&) {
+    std::shared_ptr<compute::ExecNodeOptions> options =
+        std::make_shared<compute::TableSourceNodeOptions>(input_table);
+    return compute::Declaration("table_source", {}, options, "mock_source");
+  };
+
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider = std::move(table_provider);
+
+  CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context,
+                       buf, {}, conversion_options);
+}
+
 }  // namespace engine
 }  // namespace arrow