From 6f4a3f64b209bc78140611048f703d4e5827c4d0 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 16 Aug 2022 16:06:15 +0530 Subject: [PATCH 01/22] feat(project): adding project test case for substrait with minor changes for emit --- .../engine/substrait/relation_internal.cc | 75 ++++++-- .../engine/substrait/relation_internal.h | 2 + cpp/src/arrow/engine/substrait/serde_test.cc | 162 +++++++++++++++++- 3 files changed, 225 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c5d212c8c2f..48fd6cea336 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -41,12 +41,38 @@ using internal::make_unique; namespace engine { +template +bool HasEmit(const RelMessage& rel) { + if (rel.has_common()) { + return rel.common().has_emit(); + } + return false; +} + +template +Result GetEmitExpression(const RelMessage& rel, + const std::shared_ptr& schema) { + const auto& emit = rel.common().emit(); + int emit_size = emit.output_mapping_size(); + std::vector proj_names(emit_size); + std::vector proj_field_refs(emit_size); + for (int i = 0; i < emit_size; i++) { + int32_t map_id = emit.output_mapping(i); + auto field = schema->field(map_id); + auto field_name = field->name(); + proj_names[i] = field_name; + proj_field_refs[i] = compute::field_ref(field_name); + } + auto expr = compute::project(proj_field_refs, proj_names); + return expr.Bind(*schema); +} + template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { - if (rel.common().has_emit()) { - return Status::NotImplemented("substrait::RelCommon::Emit"); - } + // if (rel.common().has_emit()) { + // return Status::NotImplemented("substrait::RelCommon::Emit"); + // } if (rel.common().has_hint()) { return Status::NotImplemented("substrait::RelCommon::Hint"); } @@ -216,12 +242,24 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(filesystem), std::move(files), std::move(format), {})); - ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema))); - + auto num_columns = static_cast(base_schema->fields().size()); + ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); + + // if(HasEmit(read)) { + // ARROW_ASSIGN_OR_RAISE(auto emit_expression, GetEmitExpression(read, + // base_schema)); return DeclarationInfo{ + // compute::Declaration::Sequence({ + // {"scan", dataset::ScanNodeOptions{std::move(ds), + // std::move(scan_options)}}, + // {"project", compute::ProjectNodeOptions{{emit_expression}}} + // }), + // num_columns, std::move(base_schema)}; + // } else { return DeclarationInfo{ compute::Declaration{ "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, - num_columns}; + num_columns, std::move(base_schema)}; + //} } case substrait::Rel::RelTypeCase::kFilter: { @@ -240,18 +278,29 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); + // if(HasEmit(filter)) { + // ARROW_ASSIGN_OR_RAISE(auto emit_expression, GetEmitExpression(filter, + // input.output_schema)); return DeclarationInfo{ + // compute::Declaration::Sequence({ + // std::move(input.declaration), + // {"filter", compute::FilterNodeOptions{std::move(condition)}}, + // {"project", compute::ProjectNodeOptions{{emit_expression}}} + // }), + // input.num_columns, + // input.output_schema}; + // } else { return DeclarationInfo{ compute::Declaration::Sequence({ std::move(input.declaration), {"filter", compute::FilterNodeOptions{std::move(condition)}}, }), - input.num_columns}; + input.num_columns, input.output_schema}; + //} } case substrait::Rel::RelTypeCase::kProject: { const auto& project = rel.project(); RETURN_NOT_OK(CheckRelCommon(project)); - if (!project.has_input()) { return Status::Invalid("substrait::ProjectRel with no input relation"); } @@ -272,12 +321,13 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } auto num_columns = static_cast(expressions.size()); + // TODO: get schema and add emit return DeclarationInfo{ compute::Declaration::Sequence({ std::move(input.declaration), {"project", compute::ProjectNodeOptions{std::move(expressions)}}, }), - num_columns}; + num_columns, arrow::schema({})}; } case substrait::Rel::RelTypeCase::kJoin: { @@ -363,7 +413,8 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& auto num_columns = left.num_columns + right.num_columns; join_dec.inputs.emplace_back(std::move(left.declaration)); join_dec.inputs.emplace_back(std::move(right.declaration)); - return DeclarationInfo{std::move(join_dec), num_columns}; + // TODO: add schema and get emit rel + return DeclarationInfo{std::move(join_dec), num_columns, arrow::schema({})}; } case substrait::Rel::RelTypeCase::kAggregate: { const auto& aggregate = rel.aggregate(); @@ -421,12 +472,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return Status::Invalid("substrait::AggregateFunction not provided"); } } - + /// TODO: add emit and extract schema return DeclarationInfo{ compute::Declaration::Sequence( {std::move(input.declaration), {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - static_cast(aggregates.size())}; + static_cast(aggregates.size()), arrow::schema({})}; } default: diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 778d1e5bc01..d8c14ea14a6 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -38,6 +38,8 @@ struct DeclarationInfo { /// The number of columns returned by the declaration. int num_columns; + + std::shared_ptr output_schema; }; /// \brief Convert a Substrait Rel object to an Acero declaration diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 9b6c3f715f7..45ce3d5d7c1 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -24,12 +24,10 @@ #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" - #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" #include "arrow/engine/substrait/serde.h" - #include "arrow/engine/substrait/util.h" #include "arrow/filesystem/localfs.h" @@ -95,6 +93,31 @@ Result> GetTableFromPlan( return arrow::Table::FromRecordBatchReader(sink_reader.get()); } +Status WriteParquetData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr input) { + EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path)); + PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(), + buffer_writer, /*chunk_size*/ 1)); + return buffer_writer->Close(); +} + +Result> GetTableFromPlan( + std::shared_ptr& plan, compute::Declaration& declarations, + arrow::AsyncGenerator>& sink_gen, + compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); + + RETURN_NOT_OK(decl->Validate()); + + std::shared_ptr sink_reader = compute::MakeGeneratorReader( + output_schema, std::move(sink_gen), exec_context.memory_pool()); + + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + return arrow::Table::FromRecordBatchReader(sink_reader.get()); +} + class NullSinkNodeConsumer : public compute::SinkNodeConsumer { public: Status Init(const std::shared_ptr&, compute::BackpressureControl*) override { @@ -2103,5 +2126,140 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } +TEST(Substrait, ProjectRel) { + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + const std::string file_name = "serde_project_test.parquet"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_project_tempdir")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + std::string toReplace("/T//"); + size_t pos = file_path_str.find(toReplace); + file_path_str.replace(pos, toReplace.length(), "/T/"); + + ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + + std::string substrait_file_uri = "file://" + file_path_str; + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "project": { + "expressions": [ + {"scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + ], + "input" : { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_file_uri + + R"(", + "parquet": {} + } + ] + } + } + } + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + substrait::default_extension_types_uri() + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "add" + }} + ] + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + for (auto sp_ext_id_reg : + {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN(auto sink_decls, + DeserializePlans( + *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + auto other_declrs = sink_decls[0].inputs[0].get(); + arrow::AsyncGenerator> sink_gen; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); + auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); + auto output_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("ADD", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 1, 10, 2], + [3, 4, 20, 7] + ])"}); + ASSERT_OK_AND_ASSIGN(auto output_table, + GetTableFromPlan(acero_plan, declarations, sink_gen, + exec_context, output_schema)); + EXPECT_TRUE(expected_table->Equals(*output_table)); + } +} + } // namespace engine } // namespace arrow From 9367049ff794618189bfe96e0b78018eda56d460 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 16 Aug 2022 21:38:31 +0530 Subject: [PATCH 02/22] feat(emit): initial version of emit with project added --- .../engine/substrait/relation_internal.cc | 138 ++++++++++------- cpp/src/arrow/engine/substrait/serde_test.cc | 139 ++++++++++++++++++ 2 files changed, 225 insertions(+), 52 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 48fd6cea336..575cd8732dd 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -50,21 +50,20 @@ bool HasEmit(const RelMessage& rel) { } template -Result GetEmitExpression(const RelMessage& rel, - const std::shared_ptr& schema) { +Result> GetEmitExpression( + const RelMessage& rel, const std::shared_ptr& schema) { const auto& emit = rel.common().emit(); int emit_size = emit.output_mapping_size(); - std::vector proj_names(emit_size); + // std::vector proj_names(emit_size); std::vector proj_field_refs(emit_size); for (int i = 0; i < emit_size; i++) { int32_t map_id = emit.output_mapping(i); - auto field = schema->field(map_id); - auto field_name = field->name(); - proj_names[i] = field_name; - proj_field_refs[i] = compute::field_ref(field_name); + // auto field = schema->field(map_id); + // auto field_name = field->name(); + // proj_names[i] = field_name; + proj_field_refs[i] = compute::field_ref(FieldRef(map_id)); } - auto expr = compute::project(proj_field_refs, proj_names); - return expr.Bind(*schema); + return std::move(proj_field_refs); } template @@ -245,21 +244,21 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& auto num_columns = static_cast(base_schema->fields().size()); ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); - // if(HasEmit(read)) { - // ARROW_ASSIGN_OR_RAISE(auto emit_expression, GetEmitExpression(read, - // base_schema)); return DeclarationInfo{ - // compute::Declaration::Sequence({ - // {"scan", dataset::ScanNodeOptions{std::move(ds), - // std::move(scan_options)}}, - // {"project", compute::ProjectNodeOptions{{emit_expression}}} - // }), - // num_columns, std::move(base_schema)}; - // } else { - return DeclarationInfo{ - compute::Declaration{ - "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, - num_columns, std::move(base_schema)}; - //} + if (HasEmit(read)) { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, + GetEmitExpression(read, base_schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {{"scan", + dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), + num_columns, std::move(base_schema)}; + } else { + return DeclarationInfo{ + compute::Declaration{ + "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, + num_columns, std::move(base_schema)}; + } } case substrait::Rel::RelTypeCase::kFilter: { @@ -278,24 +277,23 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); - // if(HasEmit(filter)) { - // ARROW_ASSIGN_OR_RAISE(auto emit_expression, GetEmitExpression(filter, - // input.output_schema)); return DeclarationInfo{ - // compute::Declaration::Sequence({ - // std::move(input.declaration), - // {"filter", compute::FilterNodeOptions{std::move(condition)}}, - // {"project", compute::ProjectNodeOptions{{emit_expression}}} - // }), - // input.num_columns, - // input.output_schema}; - // } else { - return DeclarationInfo{ - compute::Declaration::Sequence({ - std::move(input.declaration), - {"filter", compute::FilterNodeOptions{std::move(condition)}}, - }), - input.num_columns, input.output_schema}; - //} + if (HasEmit(filter)) { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, + GetEmitExpression(filter, input.output_schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(input.declaration), + {"filter", compute::FilterNodeOptions{std::move(condition)}}, + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), + input.num_columns, input.output_schema}; + } else { + return DeclarationInfo{ + compute::Declaration::Sequence({ + std::move(input.declaration), + {"filter", compute::FilterNodeOptions{std::move(condition)}}, + }), + input.num_columns, input.output_schema}; + } } case substrait::Rel::RelTypeCase::kProject: { @@ -314,20 +312,56 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& for (int i = 0; i < input.num_columns; i++) { expressions.emplace_back(compute::field_ref(FieldRef(i))); } + std::vector> new_fields(project.expressions().size()); + int i = 0; + auto project_schema = input.output_schema; for (const auto& expr : project.expressions()) { - expressions.emplace_back(); - ARROW_ASSIGN_OR_RAISE(expressions.back(), + ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr, FromProto(expr, ext_set, conversion_options)); + auto bound_expr = des_expr.Bind(*input.output_schema); + if (auto* expr_call = bound_expr->call()) { + new_fields[i] = field(expr_call->function_name, + expr_call->kernel->signature->out_type().type()); + } else if (auto* field_ref = des_expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(FieldPath field_path, + field_ref->FindOne(*input.output_schema)); + ARROW_ASSIGN_OR_RAISE(new_fields[i], field_path.Get(*input.output_schema)); + } else if (auto* literal = des_expr.literal()) { + new_fields[i] = + field("field_" + std::to_string(input.num_columns + i), literal->type()); + } + i++; + expressions.emplace_back(des_expr); + } + while (!new_fields.empty()) { + auto field = new_fields.back(); + ARROW_ASSIGN_OR_RAISE( + project_schema, + project_schema->AddField( + input.num_columns + static_cast(new_fields.size()) - 1, + std::move(field))); + new_fields.pop_back(); } - auto num_columns = static_cast(expressions.size()); - // TODO: get schema and add emit - return DeclarationInfo{ - compute::Declaration::Sequence({ - std::move(input.declaration), - {"project", compute::ProjectNodeOptions{std::move(expressions)}}, - }), - num_columns, arrow::schema({})}; + if (HasEmit(project)) { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, + GetEmitExpression(project, project_schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(input.declaration), + {"project", compute::ProjectNodeOptions{std::move(expressions)}, + "project"}, + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}, + "emit"}}), + num_columns, project_schema}; + } else { + return DeclarationInfo{ + compute::Declaration::Sequence({ + std::move(input.declaration), + {"project", compute::ProjectNodeOptions{std::move(expressions)}}, + }), + num_columns, project_schema}; + } } case substrait::Rel::RelTypeCase::kJoin: { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 45ce3d5d7c1..6e6bf8cd36f 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -2261,5 +2261,144 @@ TEST(Substrait, ProjectRel) { } } +TEST(Substrait, ProjectRelWithEmit) { + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + const std::string file_name = "serde_project_emit_test.parquet"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_project_tempdir")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + std::string toReplace("/T//"); + size_t pos = file_path_str.find(toReplace); + file_path_str.replace(pos, toReplace.length(), "/T/"); + + ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + + std::string substrait_file_uri = "file://" + file_path_str; + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "project": { + "common": { + "emit": { + "outputMapping": [0, 2] + } + }, + "expressions": [ + {"scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + ], + "input" : { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_file_uri + + R"(", + "parquet": {} + } + ] + } + } + } + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + substrait::default_extension_types_uri() + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "add" + }} + ] + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + for (auto sp_ext_id_reg : + {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN(auto sink_decls, + DeserializePlans( + *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + auto other_declrs = sink_decls[0].inputs[0].get(); + arrow::AsyncGenerator> sink_gen; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); + auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); + auto output_schema = schema({field("A", int32()), field("C", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 10], + [3, 20] + ])"}); + ASSERT_OK_AND_ASSIGN(auto output_table, + GetTableFromPlan(acero_plan, declarations, sink_gen, + exec_context, output_schema)); + EXPECT_TRUE(expected_table->Equals(*output_table)); + } +} + } // namespace engine } // namespace arrow From a83deec754c704195d4f71eb9e81a97982d35644 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 17 Aug 2022 07:22:05 +0530 Subject: [PATCH 03/22] fix(test): fixing the test feature --- cpp/src/arrow/engine/substrait/serde_test.cc | 25 ++++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 6e6bf8cd36f..ff98296e500 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -2127,6 +2127,9 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { } TEST(Substrait, ProjectRel) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2159,8 +2162,8 @@ TEST(Substrait, ProjectRel) { "relations": [{ "rel": { "project": { - "expressions": [ - {"scalarFunction": { + "expressions": [{ + "scalarFunction": { "functionReference": 0, "arguments": [{ "value": { @@ -2174,7 +2177,8 @@ TEST(Substrait, ProjectRel) { } } } - }, { + }, + { "value": { "selection": { "directReference": { @@ -2261,7 +2265,7 @@ TEST(Substrait, ProjectRel) { } } -TEST(Substrait, ProjectRelWithEmit) { +TEST(Substrait, ProjectRelOnFunctionWithEmit) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2296,11 +2300,11 @@ TEST(Substrait, ProjectRelWithEmit) { "project": { "common": { "emit": { - "outputMapping": [0, 2] + "outputMapping": [0, 2, 3] } }, - "expressions": [ - {"scalarFunction": { + "expressions": [{ + "scalarFunction": { "functionReference": 0, "arguments": [{ "value": { @@ -2388,10 +2392,11 @@ TEST(Substrait, ProjectRelWithEmit) { auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); - auto output_schema = schema({field("A", int32()), field("C", int32())}); + auto output_schema = + schema({field("A", int32()), field("C", int32()), field("add", int32())}); auto expected_table = TableFromJSON(output_schema, {R"([ - [1, 10], - [3, 20] + [1, 10, 2], + [3, 20, 7] ])"}); ASSERT_OK_AND_ASSIGN(auto output_table, GetTableFromPlan(acero_plan, declarations, sink_gen, From e55b2c83185b376106251c5902707eccf726fe71 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 17 Aug 2022 10:51:28 +0530 Subject: [PATCH 04/22] feat(data-gen): adding data generator script wip --- cpp/src/arrow/engine/substrait/serde_test.cc | 144 ++++++++++++++++++- 1 file changed, 143 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index ff98296e500..814967b2deb 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -105,7 +105,7 @@ Status WriteParquetData(const std::string& path, Result> GetTableFromPlan( std::shared_ptr& plan, compute::Declaration& declarations, arrow::AsyncGenerator>& sink_gen, - compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + compute::ExecContext& exec_context, const std::shared_ptr& output_schema) { ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); RETURN_NOT_OK(decl->Validate()); @@ -194,6 +194,75 @@ inline compute::Expression UseBoringRefs(const compute::Expression& expr) { return compute::Expression{std::move(modified_call)}; } +// TODO: complete this interface +struct TempDataGenerator { + TempDataGenerator(const std::shared_ptr
input_data, + const std::string& file_prefix, const std::string& temp_dir_prefix) {} + + Status operator()() { + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + const std::string file_name = "serde_read_emit_test.parquet"; + + ARROW_ASSIGN_OR_RAISE(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); + ARROW_ASSIGN_OR_RAISE(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + std::string toReplace("/T//"); + size_t pos = file_path_str.find(toReplace); + file_path_str.replace(pos, toReplace.length(), "/T/"); + + ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + return Status::OK(); + } +}; + +struct EmitValidate { + EmitValidate(const std::shared_ptr output_schema, + const std::shared_ptr
expected_table, + compute::ExecContext& exec_context, std::shared_ptr& buf) + : output_schema(output_schema), + expected_table(expected_table), + exec_context(exec_context), + buf(buf) {} + void operator()() { + for (auto sp_ext_id_reg : + {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN(auto sink_decls, + DeserializePlans( + *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + auto other_declrs = sink_decls[0].inputs[0].get(); + arrow::AsyncGenerator> sink_gen; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); + auto declarations = + compute::Declaration::Sequence({*other_declrs, sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN(auto output_table, + GetTableFromPlan(acero_plan, declarations, sink_gen, + exec_context, output_schema)); + EXPECT_TRUE(expected_table->Equals(*output_table)); + } + } + std::shared_ptr output_schema; + std::shared_ptr
expected_table; + compute::ExecContext exec_context; + std::shared_ptr buf; +}; + TEST(Substrait, SupportedTypes) { auto ExpectEq = [](util::string_view json, std::shared_ptr expected_type) { ARROW_SCOPED_TRACE(json); @@ -2405,5 +2474,78 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { } } +TEST(Substrait, ReadRelWithEmit) { + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + const std::string file_name = "serde_read_emit_test.parquet"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + std::string toReplace("/T//"); + size_t pos = file_path_str.find(toReplace); + file_path_str.replace(pos, toReplace.length(), "/T/"); + + ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + + std::string substrait_file_uri = "file://" + file_path_str; + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "read": { + "common": { + "emit": { + "outputMapping": [1, 2] + } + }, + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + substrait_file_uri + + R"(", + "parquet": {} + } + ] + } + } + } + }], + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("B", int32()), field("C", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 10], + [4, 20] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf); +} + } // namespace engine } // namespace arrow From 7fc83cbb018f2a3dbb6af41af4bfc0d97ad3e076 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 17 Aug 2022 21:30:59 +0530 Subject: [PATCH 05/22] fix(format): refactor to simplify tests --- cpp/src/arrow/engine/substrait/serde_test.cc | 153 +++++-------------- 1 file changed, 42 insertions(+), 111 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 814967b2deb..abb35861953 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -196,36 +196,32 @@ inline compute::Expression UseBoringRefs(const compute::Expression& expr) { // TODO: complete this interface struct TempDataGenerator { - TempDataGenerator(const std::shared_ptr
input_data, - const std::string& file_prefix, const std::string& temp_dir_prefix) {} + TempDataGenerator(const std::shared_ptr
input_table, + const std::string& file_prefix, + std::unique_ptr& tempdir) + : input_table(input_table), file_prefix(file_prefix), tempdir(tempdir) {} Status operator()() { - auto dummy_schema = - schema({field("A", int32()), field("B", int32()), field("C", int32())}); - - // creating a dummy dataset using a dummy table - auto table = TableFromJSON(dummy_schema, {R"([ - [1, 1, 10], - [3, 4, 20] - ])"}); - auto format = std::make_shared(); auto filesystem = std::make_shared(); - const std::string file_name = "serde_read_emit_test.parquet"; + const std::string file_name = file_prefix + ".parquet"; - ARROW_ASSIGN_OR_RAISE(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); ARROW_ASSIGN_OR_RAISE(auto file_path, tempdir->path().Join(file_name)); - std::string file_path_str = file_path.ToString(); + data_file_path = file_path.ToString(); std::string toReplace("/T//"); - size_t pos = file_path_str.find(toReplace); - file_path_str.replace(pos, toReplace.length(), "/T/"); + size_t pos = data_file_path.find(toReplace); + data_file_path.replace(pos, toReplace.length(), "/T/"); - ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + ARROW_EXPECT_OK(WriteParquetData(data_file_path, filesystem, input_table)); return Status::OK(); } + + std::shared_ptr
input_table; + std::string file_prefix; + std::unique_ptr& tempdir; + std::string data_file_path; }; struct EmitValidate { @@ -2204,28 +2200,17 @@ TEST(Substrait, ProjectRel) { schema({field("A", int32()), field("B", int32()), field("C", int32())}); // creating a dummy dataset using a dummy table - auto table = TableFromJSON(dummy_schema, {R"([ + auto input_table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], [3, 4, 20] ])"}); - auto format = std::make_shared(); - auto filesystem = std::make_shared(); - - const std::string file_name = "serde_project_test.parquet"; - + std::string file_prefix = "serde_project_test"; ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait_project_tempdir")); - ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); - std::string file_path_str = file_path.ToString(); - - std::string toReplace("/T//"); - size_t pos = file_path_str.find(toReplace); - file_path_str.replace(pos, toReplace.length(), "/T/"); - ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); - - std::string substrait_file_uri = "file://" + file_path_str; + TempDataGenerator datagen(input_table, file_prefix, tempdir); + std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ "relations": [{ @@ -2307,31 +2292,15 @@ TEST(Substrait, ProjectRel) { }} ] })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); - for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { - ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); - ExtensionSet ext_set(ext_id_reg); - ASSERT_OK_AND_ASSIGN(auto sink_decls, - DeserializePlans( - *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); - auto other_declrs = sink_decls[0].inputs[0].get(); - arrow::AsyncGenerator> sink_gen; - auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); - auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); - auto output_schema = schema({field("A", int32()), field("B", int32()), - field("C", int32()), field("ADD", int32())}); - auto expected_table = TableFromJSON(output_schema, {R"([ - [1, 1, 10, 2], - [3, 4, 20, 7] - ])"}); - ASSERT_OK_AND_ASSIGN(auto output_table, - GetTableFromPlan(acero_plan, declarations, sink_gen, - exec_context, output_schema)); - EXPECT_TRUE(expected_table->Equals(*output_table)); - } + auto output_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("ADD", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 1, 10, 2], + [3, 4, 20, 7] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf); } TEST(Substrait, ProjectRelOnFunctionWithEmit) { @@ -2340,28 +2309,17 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { schema({field("A", int32()), field("B", int32()), field("C", int32())}); // creating a dummy dataset using a dummy table - auto table = TableFromJSON(dummy_schema, {R"([ + auto input_table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], [3, 4, 20] ])"}); - auto format = std::make_shared(); - auto filesystem = std::make_shared(); - - const std::string file_name = "serde_project_emit_test.parquet"; + std::string file_prefix = "serde_project_emit_test"; + ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make( + "substrait_project_emit_tempdir")); - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_project_tempdir")); - ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); - std::string file_path_str = file_path.ToString(); - - std::string toReplace("/T//"); - size_t pos = file_path_str.find(toReplace); - file_path_str.replace(pos, toReplace.length(), "/T/"); - - ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); - - std::string substrait_file_uri = "file://" + file_path_str; + TempDataGenerator datagen(input_table, file_prefix, tempdir); + std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ "relations": [{ @@ -2447,31 +2405,15 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { }} ] })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); - for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { - ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); - ExtensionSet ext_set(ext_id_reg); - ASSERT_OK_AND_ASSIGN(auto sink_decls, - DeserializePlans( - *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); - auto other_declrs = sink_decls[0].inputs[0].get(); - arrow::AsyncGenerator> sink_gen; - auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); - auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); - auto output_schema = - schema({field("A", int32()), field("C", int32()), field("add", int32())}); - auto expected_table = TableFromJSON(output_schema, {R"([ + auto output_schema = + schema({field("A", int32()), field("C", int32()), field("add", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ [1, 10, 2], [3, 20, 7] - ])"}); - ASSERT_OK_AND_ASSIGN(auto output_table, - GetTableFromPlan(acero_plan, declarations, sink_gen, - exec_context, output_schema)); - EXPECT_TRUE(expected_table->Equals(*output_table)); - } + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf); } TEST(Substrait, ReadRelWithEmit) { @@ -2480,28 +2422,17 @@ TEST(Substrait, ReadRelWithEmit) { schema({field("A", int32()), field("B", int32()), field("C", int32())}); // creating a dummy dataset using a dummy table - auto table = TableFromJSON(dummy_schema, {R"([ + auto input_table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], [3, 4, 20] ])"}); - auto format = std::make_shared(); - auto filesystem = std::make_shared(); - - const std::string file_name = "serde_read_emit_test.parquet"; - ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); - ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); - std::string file_path_str = file_path.ToString(); - - std::string toReplace("/T//"); - size_t pos = file_path_str.find(toReplace); - file_path_str.replace(pos, toReplace.length(), "/T/"); - - ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + std::string file_prefix = "serde_read_emit_test"; - std::string substrait_file_uri = "file://" + file_path_str; + TempDataGenerator datagen(input_table, file_prefix, tempdir); + std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ "relations": [{ From 43fea249647e897e405f107d9e4c556595703012 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 18 Aug 2022 11:46:52 +0530 Subject: [PATCH 06/22] feat(filter): adding filter emit --- cpp/src/arrow/engine/substrait/serde_test.cc | 141 ++++++++++++++++++- 1 file changed, 137 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index abb35861953..39252e03383 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -250,7 +250,7 @@ struct EmitValidate { ASSERT_OK_AND_ASSIGN(auto output_table, GetTableFromPlan(acero_plan, declarations, sink_gen, exec_context, output_schema)); - EXPECT_TRUE(expected_table->Equals(*output_table)); + EXPECT_TRUE(expected_table->Equals(*output_table, true)); } } std::shared_ptr output_schema; @@ -2210,6 +2210,7 @@ TEST(Substrait, ProjectRel) { arrow::internal::TemporaryDir::Make("substrait_project_tempdir")); TempDataGenerator datagen(input_table, file_prefix, tempdir); + ASSERT_OK(datagen()); std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ @@ -2300,10 +2301,13 @@ TEST(Substrait, ProjectRel) { [1, 1, 10, 2], [3, 4, 20, 7] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); } TEST(Substrait, ProjectRelOnFunctionWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2319,6 +2323,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { "substrait_project_emit_tempdir")); TempDataGenerator datagen(input_table, file_prefix, tempdir); + ASSERT_OK(datagen()); std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ @@ -2413,10 +2418,13 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { [1, 10, 2], [3, 20, 7] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); } TEST(Substrait, ReadRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2432,6 +2440,7 @@ TEST(Substrait, ReadRelWithEmit) { std::string file_prefix = "serde_read_emit_test"; TempDataGenerator datagen(input_table, file_prefix, tempdir); + ASSERT_OK(datagen()); std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ @@ -2475,7 +2484,131 @@ TEST(Substrait, ReadRelWithEmit) { [1, 10], [4, 20] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); +} + +TEST(Substrait, FilterRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("D", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [10, 1, 80, 7], + [20, 2, 70, 6], + [30, 3, 30, 5], + [40, 4, 20, 4], + [40, 5, 40, 3], + [20, 6, 20, 2], + [30, 7, 30, 1] + ])"}); + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); + std::string file_prefix = "serde_read_emit_test"; + + TempDataGenerator datagen(input_table, file_prefix, tempdir); + ASSERT_OK(datagen()); + std::string substrait_file_uri = "file://" + datagen.data_file_path; + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "filter": { + "common": { + "emit": { + "outputMapping": [1, 3] + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "input" : { + "read": { + "base_schema": { + "names": ["A", "B", "C", "D"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + },{ + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_file_uri + + R"(", + "parquet": {} + } + ] + } + } + } + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + substrait::default_extension_types_uri() + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("B", int32()), field("D", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [3, 5], + [5, 3], + [6, 2], + [7, 1] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); } } // namespace engine From 136edf5d64e86c7e6892581a56d330e01d4e6e32 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 18 Aug 2022 17:47:38 +0530 Subject: [PATCH 07/22] feat(join): adding join example --- .../engine/substrait/relation_internal.cc | 5 + cpp/src/arrow/engine/substrait/serde_test.cc | 215 +++++++++++++++++- 2 files changed, 217 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 575cd8732dd..c3c787f2e56 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -439,6 +439,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& if (!left_keys || !right_keys) { return Status::Invalid("Left keys for join cannot be null"); } + + // Create output schema from left, right relations and join keys + // std::shared_ptr left_schema = left.output_schema; + // std::shared_ptr right_schema = right.output_schema; + compute::HashJoinNodeOptions join_options{{std::move(*left_keys)}, {std::move(*right_keys)}}; join_options.join_type = join_type; diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 39252e03383..5da926acc1b 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -227,11 +227,14 @@ struct TempDataGenerator { struct EmitValidate { EmitValidate(const std::shared_ptr output_schema, const std::shared_ptr
expected_table, - compute::ExecContext& exec_context, std::shared_ptr& buf) + compute::ExecContext& exec_context, std::shared_ptr& buf, + const std::vector& include_columns = {}, bool combine_chunks = false) : output_schema(output_schema), expected_table(expected_table), exec_context(exec_context), - buf(buf) {} + buf(buf), + include_columns(include_columns), + combine_chunks(combine_chunks) {} void operator()() { for (auto sp_ext_id_reg : {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { @@ -250,13 +253,33 @@ struct EmitValidate { ASSERT_OK_AND_ASSIGN(auto output_table, GetTableFromPlan(acero_plan, declarations, sink_gen, exec_context, output_schema)); - EXPECT_TRUE(expected_table->Equals(*output_table, true)); + if (!include_columns.empty()) { + ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); + } + if (combine_chunks) { + ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); + } + + EXPECT_TRUE(expected_table->Equals(*output_table)); + std::cout << "output" << std::endl; + std::cout << std::string(10, '#') << std::endl; + std::cout << output_table->ToString() << std::endl; + std::cout << std::string(10, '#') << std::endl; + std::cout << output_table->schema()->ToString(false) << std::endl; + + std::cout << "expected" << std::endl; + std::cout << std::string(10, '#') << std::endl; + std::cout << expected_table->ToString() << std::endl; + std::cout << std::string(10, '#') << std::endl; + std::cout << expected_table->schema()->ToString(false) << std::endl; } } std::shared_ptr output_schema; std::shared_ptr
expected_table; compute::ExecContext exec_context; std::shared_ptr buf; + const std::vector& include_columns; + bool combine_chunks; }; TEST(Substrait, SupportedTypes) { @@ -2611,5 +2634,191 @@ TEST(Substrait, FilterRelWithEmit) { EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); } +TEST(Substrait, JoinRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto left_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("D", int32())}); + + auto right_schema = schema({field("X", int32()), field("Y", int32()), + field("Z", int32()), field("W", int32())}); + + // creating a dummy dataset using a dummy table + auto left_table = TableFromJSON(left_schema, {R"([ + [10, 1, 80, 70], + [20, 2, 70, 60], + [30, 3, 30, 50] + ])"}); + + auto right_table = TableFromJSON(right_schema, {R"([ + [10, 1, 81, 71], + [80, 2, 71, 61], + [31, 3, 31, 51] + ])"}); + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_join_tempdir")); + std::string left_file_prefix = "serde_join_left_emit_test"; + std::string right_file_prefix = "serde_join_right_emit_test"; + + TempDataGenerator datagen_left(left_table, left_file_prefix, tempdir); + ASSERT_OK(datagen_left()); + std::string substrait_left_file_uri = "file://" + datagen_left.data_file_path; + + TempDataGenerator datagen_right(right_table, right_file_prefix, tempdir); + ASSERT_OK(datagen_right()); + std::string substrait_right_file_uri = "file://" + datagen_right.data_file_path; + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "join": { + "left": { + "read": { + "base_schema": { + "names": ["A", "B", "C", "D"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_left_file_uri + + R"(", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "base_schema": { + "names": ["X", "Y", "Z", "W"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_right_file_uri + + R"(", + "parquet": {} + } + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_INNER" + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + substrait::default_extension_types_uri() + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({ + field("A", int32()), + field("B", int32()), + field("C", int32()), + field("D", int32()), + field("__fragment_index_l", int32()), + field("__batch_index_l", int32()), + field("__last_in_fragment_l", boolean()), + field("__filename_l", utf8()), + field("X", int32()), + field("Y", int32()), + field("Z", int32()), + field("W", int32()), + field("__fragment_index_r", int32()), + field("__batch_index_r", int32()), + field("__last_in_fragment_r", boolean()), + field("__filename_r", utf8()), + }); + + // include these columns for comparison + std::vector include_columns{0, 1, 2, 3, 8, 9, 10, 11}; + auto compared_output_schema = schema({ + field("A", int32()), + field("B", int32()), + field("C", int32()), + field("D", int32()), + field("X", int32()), + field("Y", int32()), + field("Z", int32()), + field("W", int32()), + }); + auto expected_table = TableFromJSON(std::move(compared_output_schema), {R"([ + [10, 1, 80, 70, 10, 1, 81, 71] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf, + std::move(include_columns), true)(); +} + } // namespace engine } // namespace arrow From a2c08a2a75b089c6a3a5f52f18e6c888dc84fb9b Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Fri, 19 Aug 2022 11:21:55 +0530 Subject: [PATCH 08/22] fix(rebase): merge with substrait changes --- .../engine/substrait/relation_internal.cc | 27 ++++++++++++++----- cpp/src/arrow/engine/substrait/serde_test.cc | 18 ++++++++----- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c3c787f2e56..cc207322898 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -129,7 +129,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& named_table.names().end()); ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, named_table_provider(table_names)); - return DeclarationInfo{std::move(source_decl), num_columns}; + return DeclarationInfo{std::move(source_decl), num_columns, base_schema}; } if (!read.has_local_files()) { @@ -241,7 +241,6 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(filesystem), std::move(files), std::move(format), {})); - auto num_columns = static_cast(base_schema->fields().size()); ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); if (HasEmit(read)) { @@ -441,8 +440,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } // Create output schema from left, right relations and join keys - // std::shared_ptr left_schema = left.output_schema; - // std::shared_ptr right_schema = right.output_schema; + std::shared_ptr join_schema = left.output_schema; + std::shared_ptr right_schema = right.output_schema; + + for (const auto& field : right_schema->fields()) { + ARROW_ASSIGN_OR_RAISE( + join_schema, join_schema->AddField( + static_cast(join_schema->fields().size()) - 1, field)); + } compute::HashJoinNodeOptions join_options{{std::move(*left_keys)}, {std::move(*right_keys)}}; @@ -452,8 +457,18 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& auto num_columns = left.num_columns + right.num_columns; join_dec.inputs.emplace_back(std::move(left.declaration)); join_dec.inputs.emplace_back(std::move(right.declaration)); - // TODO: add schema and get emit rel - return DeclarationInfo{std::move(join_dec), num_columns, arrow::schema({})}; + + if (HasEmit(join)) { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, + GetEmitExpression(join, join_schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(join_dec), + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), + num_columns, join_schema}; + } else { + return DeclarationInfo{std::move(join_dec), num_columns, join_schema}; + } } case substrait::Rel::RelTypeCase::kAggregate: { const auto& aggregate = rel.aggregate(); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 5da926acc1b..6a490b155bb 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -237,7 +237,7 @@ struct EmitValidate { combine_chunks(combine_chunks) {} void operator()() { for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_OK_AND_ASSIGN(auto sink_decls, @@ -2255,7 +2255,7 @@ TEST(Substrait, ProjectRel) { } } } - }, + }, { "value": { "selection": { @@ -2573,7 +2573,10 @@ TEST(Substrait, FilterRelWithEmit) { } } } - }] + }], + "output_type": { + "bool": {} + } } }, "input" : { @@ -2610,7 +2613,7 @@ TEST(Substrait, FilterRelWithEmit) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], @@ -2758,7 +2761,10 @@ TEST(Substrait, JoinRelWithEmit) { } } } - }] + }], + "output_type": { + "bool": {} + } } }, "type": "JOIN_TYPE_INNER" @@ -2768,7 +2774,7 @@ TEST(Substrait, JoinRelWithEmit) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], From 9578a91406ceff38bbd208df791234ff179c7b9f Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Fri, 19 Aug 2022 11:30:58 +0530 Subject: [PATCH 09/22] fix(project): replaced the add op with equal for test case --- cpp/src/arrow/engine/substrait/serde_test.cc | 23 +++----------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 6a490b155bb..f2259b5a218 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -228,13 +228,12 @@ struct EmitValidate { EmitValidate(const std::shared_ptr output_schema, const std::shared_ptr
expected_table, compute::ExecContext& exec_context, std::shared_ptr& buf, - const std::vector& include_columns = {}, bool combine_chunks = false) + const std::vector& include_columns = {}) : output_schema(output_schema), expected_table(expected_table), exec_context(exec_context), buf(buf), - include_columns(include_columns), - combine_chunks(combine_chunks) {} + include_columns(include_columns) {} void operator()() { for (auto sp_ext_id_reg : {std::shared_ptr(), MakeExtensionIdRegistry()}) { @@ -256,22 +255,7 @@ struct EmitValidate { if (!include_columns.empty()) { ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); } - if (combine_chunks) { - ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); - } - EXPECT_TRUE(expected_table->Equals(*output_table)); - std::cout << "output" << std::endl; - std::cout << std::string(10, '#') << std::endl; - std::cout << output_table->ToString() << std::endl; - std::cout << std::string(10, '#') << std::endl; - std::cout << output_table->schema()->ToString(false) << std::endl; - - std::cout << "expected" << std::endl; - std::cout << std::string(10, '#') << std::endl; - std::cout << expected_table->ToString() << std::endl; - std::cout << std::string(10, '#') << std::endl; - std::cout << expected_table->schema()->ToString(false) << std::endl; } } std::shared_ptr output_schema; @@ -279,7 +263,6 @@ struct EmitValidate { compute::ExecContext exec_context; std::shared_ptr buf; const std::vector& include_columns; - bool combine_chunks; }; TEST(Substrait, SupportedTypes) { @@ -2823,7 +2806,7 @@ TEST(Substrait, JoinRelWithEmit) { [10, 1, 80, 70, 10, 1, 81, 71] ])"}); EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf, - std::move(include_columns), true)(); + std::move(include_columns))(); } } // namespace engine From fb77dc19ce07ea67ad3fc3a6e89b2cb7665baf3b Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Fri, 19 Aug 2022 12:54:36 +0530 Subject: [PATCH 10/22] feat(aggreagte): basic end-to-end test added --- .../engine/substrait/relation_internal.cc | 4 +- cpp/src/arrow/engine/substrait/serde_test.cc | 127 ++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index cc207322898..9b2f3921318 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -465,9 +465,9 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& compute::Declaration::Sequence( {std::move(join_dec), {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - num_columns, join_schema}; + num_columns, std::move(join_schema)}; } else { - return DeclarationInfo{std::move(join_dec), num_columns, join_schema}; + return DeclarationInfo{std::move(join_dec), num_columns, std::move(join_schema)}; } } case substrait::Rel::RelTypeCase::kAggregate: { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index f2259b5a218..c66bec77988 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -255,7 +255,18 @@ struct EmitValidate { if (!include_columns.empty()) { ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); } + ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); EXPECT_TRUE(expected_table->Equals(*output_table)); + + std::cout << "output" << std::endl; + std::cout << std::string(20, '-') << std::endl; + std::cout << output_table->ToString() << std::endl; + std::cout << std::string(20, '-') << std::endl; + + std::cout << "expected" << std::endl; + std::cout << std::string(20, '-') << std::endl; + std::cout << expected_table->ToString() << std::endl; + std::cout << std::string(20, '-') << std::endl; } } std::shared_ptr output_schema; @@ -2809,5 +2820,121 @@ TEST(Substrait, JoinRelWithEmit) { std::move(include_columns))(); } +TEST(Substrait, AggregateRelEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [10, 1, 80], + [20, 2, 70], + [30, 3, 30], + [40, 4, 20], + [40, 5, 40], + [20, 6, 20], + [30, 7, 30] + ])"}); + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_agg_tempdir")); + std::string file_prefix = "serde_agg_emit_test"; + + TempDataGenerator datagen(input_table, file_prefix, tempdir); + ASSERT_OK(datagen()); + std::string substrait_file_uri = "file://" + datagen.data_file_path; + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "aggregate": { + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_file_uri + + R"(", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + } + } + } + }], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "invocation": "AGGREGATION_INVOCATION_ALL", + "outputType": { + "i64": {} + } + } + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "sum" + } + }], + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("aggregates", int64()), field("keys", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [80, 10], + [90, 20], + [60, 30], + [60, 40] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); +} + } // namespace engine } // namespace arrow From ea2a05c4049c7d4f4270742f3e8ff2efe83f2b45 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Sun, 21 Aug 2022 07:45:40 +0530 Subject: [PATCH 11/22] feat(agg): adding aggregate feature for emits --- .../engine/substrait/relation_internal.cc | 60 +++- cpp/src/arrow/engine/substrait/serde_test.cc | 317 +++++++++++++++++- 2 files changed, 354 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 9b2f3921318..f16bb4a910c 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -54,15 +54,12 @@ Result> GetEmitExpression( const RelMessage& rel, const std::shared_ptr& schema) { const auto& emit = rel.common().emit(); int emit_size = emit.output_mapping_size(); - // std::vector proj_names(emit_size); std::vector proj_field_refs(emit_size); for (int i = 0; i < emit_size; i++) { int32_t map_id = emit.output_mapping(i); - // auto field = schema->field(map_id); - // auto field_name = field->name(); - // proj_names[i] = field_name; proj_field_refs[i] = compute::field_ref(FieldRef(map_id)); } + // TODO: return emit size and expression as a tuple return std::move(proj_field_refs); } @@ -486,16 +483,25 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& "Grouping sets not supported. AggregateRel::groupings may not have more " "than one item"); } + + // prepare output schema from aggregates + auto input_schema = input.output_schema; + // store key fields to be used when output schema is created + std::vector key_field_ids; std::vector keys; if (aggregate.groupings_size() > 0) { const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); - keys.reserve(group.grouping_expressions_size()); - for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) { + int grouping_expr_size = group.grouping_expressions_size(); + keys.reserve(grouping_expr_size); + key_field_ids.reserve(grouping_expr_size); + for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) { ARROW_ASSIGN_OR_RAISE( compute::Expression expr, FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options)); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { + ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); + key_field_ids.emplace_back(std::move(match[0])); keys.emplace_back(std::move(*field_ref)); } else { return Status::Invalid( @@ -507,6 +513,8 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& int measure_size = aggregate.measures_size(); std::vector aggregates; aggregates.reserve(measure_size); + // store aggregate fields to be used when output schema is created + std::vector agg_src_field_ids(measure_size); for (int measure_id = 0; measure_id < measure_size; measure_id++) { const auto& agg_measure = aggregate.measures(measure_id); if (agg_measure.has_measure()) { @@ -521,17 +529,45 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ExtensionIdRegistry::SubstraitAggregateToArrow converter, ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id())); ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); + + // find aggregate field ids from schema + const auto field_ref = arrow_agg.target; + ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); + agg_src_field_ids[measure_id] = match[0]; + aggregates.push_back(std::move(arrow_agg)); } else { return Status::Invalid("substrait::AggregateFunction not provided"); } } - /// TODO: add emit and extract schema - return DeclarationInfo{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - static_cast(aggregates.size()), arrow::schema({})}; + FieldVector output_fields; + output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size()); + // extract aggregate fields to output schema + for (int id = 0; id < static_cast(agg_src_field_ids.size()); id++) { + output_fields.emplace_back(input_schema->field(agg_src_field_ids[id])); + } + // extract key fields to output schema + for (int id = 0; id < static_cast(key_field_ids.size()); id++) { + output_fields.emplace_back(input_schema->field(key_field_ids[id])); + } + + std::shared_ptr aggregate_schema = schema(std::move(output_fields)); + if (HasEmit(aggregate)) { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, + GetEmitExpression(aggregate, aggregate_schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(input.declaration), + {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}, + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), + static_cast(aggregates.size()), std::move(aggregate_schema)}; + } else { + return DeclarationInfo{ + compute::Declaration::Sequence( + {std::move(input.declaration), + {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), + static_cast(aggregates.size()), std::move(aggregate_schema)}; + } } default: diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index c66bec77988..2960f92edc5 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -258,15 +258,15 @@ struct EmitValidate { ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); EXPECT_TRUE(expected_table->Equals(*output_table)); - std::cout << "output" << std::endl; - std::cout << std::string(20, '-') << std::endl; - std::cout << output_table->ToString() << std::endl; - std::cout << std::string(20, '-') << std::endl; - - std::cout << "expected" << std::endl; - std::cout << std::string(20, '-') << std::endl; - std::cout << expected_table->ToString() << std::endl; - std::cout << std::string(20, '-') << std::endl; + // std::cout << "output" << std::endl; + // std::cout << std::string(20, '-') << std::endl; + // std::cout << output_table->ToString() << std::endl; + // std::cout << std::string(20, '-') << std::endl; + + // std::cout << "expected" << std::endl; + // std::cout << std::string(20, '-') << std::endl; + // std::cout << expected_table->ToString() << std::endl; + // std::cout << std::string(20, '-') << std::endl; } } std::shared_ptr output_schema; @@ -2631,7 +2631,7 @@ TEST(Substrait, FilterRelWithEmit) { EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); } -TEST(Substrait, JoinRelWithEmit) { +TEST(Substrait, JoinRelEndToEnd) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif @@ -2820,7 +2820,180 @@ TEST(Substrait, JoinRelWithEmit) { std::move(include_columns))(); } -TEST(Substrait, AggregateRelEmit) { +TEST(Substrait, JoinRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto left_schema = schema({field("A", int32()), field("B", int32()), + field("C", int32()), field("D", int32())}); + + auto right_schema = schema({field("X", int32()), field("Y", int32()), + field("Z", int32()), field("W", int32())}); + + // creating a dummy dataset using a dummy table + auto left_table = TableFromJSON(left_schema, {R"([ + [10, 1, 80, 70], + [20, 2, 70, 60], + [30, 3, 30, 50] + ])"}); + + auto right_table = TableFromJSON(right_schema, {R"([ + [10, 1, 81, 71], + [80, 2, 71, 61], + [31, 3, 31, 51] + ])"}); + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_join_tempdir")); + std::string left_file_prefix = "serde_join_left_emit_test"; + std::string right_file_prefix = "serde_join_right_emit_test"; + + TempDataGenerator datagen_left(left_table, left_file_prefix, tempdir); + ASSERT_OK(datagen_left()); + std::string substrait_left_file_uri = "file://" + datagen_left.data_file_path; + + TempDataGenerator datagen_right(right_table, right_file_prefix, tempdir); + ASSERT_OK(datagen_right()); + std::string substrait_right_file_uri = "file://" + datagen_right.data_file_path; + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "join": { + "common": { + "emit": { + "outputMapping": [0, 1, 2, 3, 8, 9, 10, 11] + } + }, + "left": { + "read": { + "base_schema": { + "names": ["A", "B", "C", "D"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_left_file_uri + + R"(", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "base_schema": { + "names": ["X", "Y", "Z", "W"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_right_file_uri + + R"(", + "parquet": {} + } + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "output_type": { + "bool": {} + } + } + }, + "type": "JOIN_TYPE_INNER" + } + } + }], + "extension_uris": [ + { + "extension_uri_anchor": 0, + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + }} + ] + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({ + field("A", int32()), + field("B", int32()), + field("C", int32()), + field("D", int32()), + field("X", int32()), + field("Y", int32()), + field("Z", int32()), + field("W", int32()), + }); + + auto expected_table = TableFromJSON(std::move(output_schema), {R"([ + [10, 1, 80, 70, 10, 1, 81, 71] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); +} + +TEST(Substrait, AggregateRel) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif @@ -2936,5 +3109,127 @@ TEST(Substrait, AggregateRelEmit) { EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); } +TEST(Substrait, AggregateRelEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [10, 1, 80], + [20, 2, 70], + [30, 3, 30], + [40, 4, 20], + [40, 5, 40], + [20, 6, 20], + [30, 7, 30] + ])"}); + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_agg_tempdir")); + std::string file_prefix = "serde_agg_emit_test"; + + TempDataGenerator datagen(input_table, file_prefix, tempdir); + ASSERT_OK(datagen()); + std::string substrait_file_uri = "file://" + datagen.data_file_path; + // TODO: fixme https://issues.apache.org/jira/browse/ARROW-17484 + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "aggregate": { + "common": { + "emit": { + "outputMapping": [0] + } + }, + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": ")" + + substrait_file_uri + + R"(", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + } + } + } + }], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "invocation": "AGGREGATION_INVOCATION_ALL", + "outputType": { + "i64": {} + } + } + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "sum" + } + }], + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("aggregates", int64())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [80], + [90], + [60], + [60] + ])"}); + EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); +} + } // namespace engine } // namespace arrow From 8da8b541289d203b3534bf9bb56f1d6743b9465d Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Sun, 21 Aug 2022 08:00:50 +0530 Subject: [PATCH 12/22] fix(num_columns): fix the number of columns for emit feature --- .../engine/substrait/relation_internal.cc | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index f16bb4a910c..46a5560bd74 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -50,7 +50,7 @@ bool HasEmit(const RelMessage& rel) { } template -Result> GetEmitExpression( +Result> GetEmitInfo( const RelMessage& rel, const std::shared_ptr& schema) { const auto& emit = rel.common().emit(); int emit_size = emit.output_mapping_size(); @@ -241,14 +241,13 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); if (HasEmit(read)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitExpression(read, base_schema)); + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(read, base_schema)); return DeclarationInfo{ compute::Declaration::Sequence( {{"scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - num_columns, std::move(base_schema)}; + static_cast(emit_expressions.size()), std::move(base_schema)}; } else { return DeclarationInfo{ compute::Declaration{ @@ -275,13 +274,13 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& if (HasEmit(filter)) { ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitExpression(filter, input.output_schema)); + GetEmitInfo(filter, input.output_schema)); return DeclarationInfo{ compute::Declaration::Sequence( {std::move(input.declaration), {"filter", compute::FilterNodeOptions{std::move(condition)}}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - input.num_columns, input.output_schema}; + static_cast(emit_expressions.size()), input.output_schema}; } else { return DeclarationInfo{ compute::Declaration::Sequence({ @@ -341,7 +340,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& auto num_columns = static_cast(expressions.size()); if (HasEmit(project)) { ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitExpression(project, project_schema)); + GetEmitInfo(project, project_schema)); return DeclarationInfo{ compute::Declaration::Sequence( {std::move(input.declaration), @@ -349,7 +348,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& "project"}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}, "emit"}}), - num_columns, project_schema}; + static_cast(emit_expressions.size()), project_schema}; } else { return DeclarationInfo{ compute::Declaration::Sequence({ @@ -456,13 +455,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& join_dec.inputs.emplace_back(std::move(right.declaration)); if (HasEmit(join)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitExpression(join, join_schema)); + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(join, join_schema)); return DeclarationInfo{ compute::Declaration::Sequence( {std::move(join_dec), {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - num_columns, std::move(join_schema)}; + static_cast(emit_expressions.size()), std::move(join_schema)}; } else { return DeclarationInfo{std::move(join_dec), num_columns, std::move(join_schema)}; } @@ -554,13 +552,13 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::shared_ptr aggregate_schema = schema(std::move(output_fields)); if (HasEmit(aggregate)) { ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitExpression(aggregate, aggregate_schema)); + GetEmitInfo(aggregate, aggregate_schema)); return DeclarationInfo{ compute::Declaration::Sequence( {std::move(input.declaration), {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - static_cast(aggregates.size()), std::move(aggregate_schema)}; + static_cast(emit_expressions.size()), std::move(aggregate_schema)}; } else { return DeclarationInfo{ compute::Declaration::Sequence( From 1f4da760917d846a5f622f4502db3ed186d802e2 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Sun, 21 Aug 2022 11:04:20 +0530 Subject: [PATCH 13/22] fix(cleanup): cleaning up code --- cpp/src/arrow/engine/substrait/relation_internal.cc | 4 ---- cpp/src/arrow/engine/substrait/serde_test.cc | 10 ---------- 2 files changed, 14 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 46a5560bd74..921cebc5871 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -59,16 +59,12 @@ Result> GetEmitInfo( int32_t map_id = emit.output_mapping(i); proj_field_refs[i] = compute::field_ref(FieldRef(map_id)); } - // TODO: return emit size and expression as a tuple return std::move(proj_field_refs); } template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { - // if (rel.common().has_emit()) { - // return Status::NotImplemented("substrait::RelCommon::Emit"); - // } if (rel.common().has_hint()) { return Status::NotImplemented("substrait::RelCommon::Hint"); } diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 2960f92edc5..f4a98a88e92 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -257,16 +257,6 @@ struct EmitValidate { } ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); EXPECT_TRUE(expected_table->Equals(*output_table)); - - // std::cout << "output" << std::endl; - // std::cout << std::string(20, '-') << std::endl; - // std::cout << output_table->ToString() << std::endl; - // std::cout << std::string(20, '-') << std::endl; - - // std::cout << "expected" << std::endl; - // std::cout << std::string(20, '-') << std::endl; - // std::cout << expected_table->ToString() << std::endl; - // std::cout << std::string(20, '-') << std::endl; } } std::shared_ptr output_schema; From d99ddf432ad232424ae815d0ce7a6a2e578449e0 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 31 Aug 2022 21:03:17 +0530 Subject: [PATCH 14/22] fix(reviews): remove column count from DeclarationInfo --- .../engine/substrait/relation_internal.cc | 36 +++++++++---------- .../engine/substrait/relation_internal.h | 3 -- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 921cebc5871..3849bd08968 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -93,7 +93,6 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set, conversion_options)); - auto num_columns = static_cast(base_schema->fields().size()); auto scan_options = std::make_shared(); scan_options->use_threads = true; @@ -122,7 +121,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& named_table.names().end()); ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, named_table_provider(table_names)); - return DeclarationInfo{std::move(source_decl), num_columns, base_schema}; + return DeclarationInfo{std::move(source_decl), base_schema}; } if (!read.has_local_files()) { @@ -243,12 +242,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& {{"scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - static_cast(emit_expressions.size()), std::move(base_schema)}; + std::move(base_schema)}; } else { return DeclarationInfo{ compute::Declaration{ "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, - num_columns, std::move(base_schema)}; + std::move(base_schema)}; } } @@ -276,14 +275,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& {std::move(input.declaration), {"filter", compute::FilterNodeOptions{std::move(condition)}}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - static_cast(emit_expressions.size()), input.output_schema}; + input.output_schema}; } else { return DeclarationInfo{ compute::Declaration::Sequence({ std::move(input.declaration), {"filter", compute::FilterNodeOptions{std::move(condition)}}, }), - input.num_columns, input.output_schema}; + input.output_schema}; } } @@ -299,8 +298,9 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces // them. Therefore, we need to prefix all the current columns for compatibility. std::vector expressions; - expressions.reserve(input.num_columns + project.expressions().size()); - for (int i = 0; i < input.num_columns; i++) { + int num_columns = input.output_schema->num_fields(); + expressions.reserve(num_columns + project.expressions().size()); + for (int i = 0; i < num_columns; i++) { expressions.emplace_back(compute::field_ref(FieldRef(i))); } std::vector> new_fields(project.expressions().size()); @@ -319,7 +319,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(new_fields[i], field_path.Get(*input.output_schema)); } else if (auto* literal = des_expr.literal()) { new_fields[i] = - field("field_" + std::to_string(input.num_columns + i), literal->type()); + field("field_" + std::to_string(num_columns + i), literal->type()); } i++; expressions.emplace_back(des_expr); @@ -329,11 +329,10 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE( project_schema, project_schema->AddField( - input.num_columns + static_cast(new_fields.size()) - 1, - std::move(field))); + num_columns + static_cast(new_fields.size()) - 1, std::move(field))); new_fields.pop_back(); } - auto num_columns = static_cast(expressions.size()); + if (HasEmit(project)) { ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(project, project_schema)); @@ -344,14 +343,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& "project"}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}, "emit"}}), - static_cast(emit_expressions.size()), project_schema}; + project_schema}; } else { return DeclarationInfo{ compute::Declaration::Sequence({ std::move(input.declaration), {"project", compute::ProjectNodeOptions{std::move(expressions)}}, }), - num_columns, project_schema}; + project_schema}; } } @@ -446,7 +445,6 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& join_options.join_type = join_type; join_options.key_cmp = {join_key_cmp}; compute::Declaration join_dec{"hashjoin", std::move(join_options)}; - auto num_columns = left.num_columns + right.num_columns; join_dec.inputs.emplace_back(std::move(left.declaration)); join_dec.inputs.emplace_back(std::move(right.declaration)); @@ -456,9 +454,9 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& compute::Declaration::Sequence( {std::move(join_dec), {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - static_cast(emit_expressions.size()), std::move(join_schema)}; + std::move(join_schema)}; } else { - return DeclarationInfo{std::move(join_dec), num_columns, std::move(join_schema)}; + return DeclarationInfo{std::move(join_dec), std::move(join_schema)}; } } case substrait::Rel::RelTypeCase::kAggregate: { @@ -554,13 +552,13 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& {std::move(input.declaration), {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}, {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - static_cast(emit_expressions.size()), std::move(aggregate_schema)}; + std::move(aggregate_schema)}; } else { return DeclarationInfo{ compute::Declaration::Sequence( {std::move(input.declaration), {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - static_cast(aggregates.size()), std::move(aggregate_schema)}; + std::move(aggregate_schema)}; } } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index d8c14ea14a6..514f3f97fc0 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -36,9 +36,6 @@ struct DeclarationInfo { /// The compute declaration produced thus far. compute::Declaration declaration; - /// The number of columns returned by the declaration. - int num_columns; - std::shared_ptr output_schema; }; From 81ad00b8c5168ce3c7e3616ac4492cdd55fffe09 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 31 Aug 2022 21:29:38 +0530 Subject: [PATCH 15/22] fix(reviews): removed a redundant loop --- .../engine/substrait/relation_internal.cc | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 3849bd08968..c161f089095 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -303,34 +303,32 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& for (int i = 0; i < num_columns; i++) { expressions.emplace_back(compute::field_ref(FieldRef(i))); } - std::vector> new_fields(project.expressions().size()); + int i = 0; auto project_schema = input.output_schema; for (const auto& expr : project.expressions()) { + std::shared_ptr project_field; ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr, FromProto(expr, ext_set, conversion_options)); auto bound_expr = des_expr.Bind(*input.output_schema); if (auto* expr_call = bound_expr->call()) { - new_fields[i] = field(expr_call->function_name, + project_field = field(expr_call->function_name, expr_call->kernel->signature->out_type().type()); } else if (auto* field_ref = des_expr.field_ref()) { ARROW_ASSIGN_OR_RAISE(FieldPath field_path, field_ref->FindOne(*input.output_schema)); - ARROW_ASSIGN_OR_RAISE(new_fields[i], field_path.Get(*input.output_schema)); + ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema)); } else if (auto* literal = des_expr.literal()) { - new_fields[i] = + project_field = field("field_" + std::to_string(num_columns + i), literal->type()); } - i++; - expressions.emplace_back(des_expr); - } - while (!new_fields.empty()) { - auto field = new_fields.back(); ARROW_ASSIGN_OR_RAISE( project_schema, project_schema->AddField( - num_columns + static_cast(new_fields.size()) - 1, std::move(field))); - new_fields.pop_back(); + num_columns + static_cast(project.expressions().size()) - 1, + std::move(project_field))); + i++; + expressions.emplace_back(des_expr); } if (HasEmit(project)) { From bba665f17c06bbe09e677adc6ee91c16135ce7cf Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 31 Aug 2022 23:38:17 +0530 Subject: [PATCH 16/22] fix(reviews): updated the emit processing logic and added switch cases --- .../engine/substrait/relation_internal.cc | 138 ++++++++---------- 1 file changed, 61 insertions(+), 77 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c161f089095..e01ee5050bd 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -62,6 +62,30 @@ Result> GetEmitInfo( return std::move(proj_field_refs); } +template +Result ProcessEmit(const RelMessage& rel, + const DeclarationInfo& no_emit_declr, + const std::shared_ptr& schema) { + if (rel.has_common()) { + switch (rel.common().emit_kind_case()) { + case substrait::RelCommon::EmitKindCase::kDirect: + return no_emit_declr; + case substrait::RelCommon::EmitKindCase::kEmit: { + ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(rel, schema)); + return DeclarationInfo{ + compute::Declaration::Sequence( + {no_emit_declr.declaration, + {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), + std::move(schema)}; + } + default: + return Status::Invalid("Invalid emit case"); + } + } else { + return no_emit_declr; + } +} + template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { @@ -235,20 +259,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); - if (HasEmit(read)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(read, base_schema)); - return DeclarationInfo{ - compute::Declaration::Sequence( - {{"scan", - dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, - {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - std::move(base_schema)}; - } else { - return DeclarationInfo{ - compute::Declaration{ - "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, - std::move(base_schema)}; - } + DeclarationInfo no_emit_declaration = { + compute::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}}, + base_schema}; + + return ProcessEmit(std::move(read), std::move(no_emit_declaration), + std::move(base_schema)); } case substrait::Rel::RelTypeCase::kFilter: { @@ -266,24 +282,15 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); - - if (HasEmit(filter)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitInfo(filter, input.output_schema)); - return DeclarationInfo{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"filter", compute::FilterNodeOptions{std::move(condition)}}, - {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - input.output_schema}; - } else { - return DeclarationInfo{ - compute::Declaration::Sequence({ - std::move(input.declaration), - {"filter", compute::FilterNodeOptions{std::move(condition)}}, - }), - input.output_schema}; - } + DeclarationInfo no_emit_declaration{ + compute::Declaration::Sequence({ + std::move(input.declaration), + {"filter", compute::FilterNodeOptions{std::move(condition)}}, + }), + input.output_schema}; + + return ProcessEmit(std::move(filter), std::move(no_emit_declaration), + input.output_schema); } case substrait::Rel::RelTypeCase::kProject: { @@ -331,25 +338,15 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& expressions.emplace_back(des_expr); } - if (HasEmit(project)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitInfo(project, project_schema)); - return DeclarationInfo{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"project", compute::ProjectNodeOptions{std::move(expressions)}, - "project"}, - {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}, - "emit"}}), - project_schema}; - } else { - return DeclarationInfo{ - compute::Declaration::Sequence({ - std::move(input.declaration), - {"project", compute::ProjectNodeOptions{std::move(expressions)}}, - }), - project_schema}; - } + DeclarationInfo no_emit_declaration{ + compute::Declaration::Sequence({ + std::move(input.declaration), + {"project", compute::ProjectNodeOptions{std::move(expressions)}}, + }), + project_schema}; + + return ProcessEmit(std::move(project), std::move(no_emit_declaration), + std::move(project_schema)); } case substrait::Rel::RelTypeCase::kJoin: { @@ -446,16 +443,10 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& join_dec.inputs.emplace_back(std::move(left.declaration)); join_dec.inputs.emplace_back(std::move(right.declaration)); - if (HasEmit(join)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(join, join_schema)); - return DeclarationInfo{ - compute::Declaration::Sequence( - {std::move(join_dec), - {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - std::move(join_schema)}; - } else { - return DeclarationInfo{std::move(join_dec), std::move(join_schema)}; - } + DeclarationInfo no_emit_declaration{std::move(join_dec), join_schema}; + + return ProcessEmit(std::move(join), std::move(no_emit_declaration), + std::move(join_schema)); } case substrait::Rel::RelTypeCase::kAggregate: { const auto& aggregate = rel.aggregate(); @@ -542,22 +533,15 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - if (HasEmit(aggregate)) { - ARROW_ASSIGN_OR_RAISE(auto emit_expressions, - GetEmitInfo(aggregate, aggregate_schema)); - return DeclarationInfo{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}, - {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), - std::move(aggregate_schema)}; - } else { - return DeclarationInfo{ - compute::Declaration::Sequence( - {std::move(input.declaration), - {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), - std::move(aggregate_schema)}; - } + + DeclarationInfo no_emit_declaration{ + compute::Declaration::Sequence( + {std::move(input.declaration), + {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), + aggregate_schema}; + + return ProcessEmit(std::move(aggregate), std::move(no_emit_declaration), + std::move(aggregate_schema)); } default: From 7eb462399f02ad69903802b68ac6b6567dc60c0c Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 1 Sep 2022 06:26:55 +0530 Subject: [PATCH 17/22] fix(path_issue): added a check for replacing clause --- cpp/src/arrow/engine/substrait/serde_test.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index f4a98a88e92..3ed51d17c85 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -212,7 +212,9 @@ struct TempDataGenerator { std::string toReplace("/T//"); size_t pos = data_file_path.find(toReplace); - data_file_path.replace(pos, toReplace.length(), "/T/"); + if (pos >= 0) { + data_file_path.replace(pos, toReplace.length(), "/T/"); + } ARROW_EXPECT_OK(WriteParquetData(data_file_path, filesystem, input_table)); return Status::OK(); From 54b18df9b2835845b11660a2f057ee1e74351415 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 1 Sep 2022 13:49:49 +0530 Subject: [PATCH 18/22] fix(path): remove temp path fix --- cpp/src/arrow/engine/substrait/serde_test.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 3ed51d17c85..66806483f6d 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -204,18 +204,9 @@ struct TempDataGenerator { Status operator()() { auto format = std::make_shared(); auto filesystem = std::make_shared(); - const std::string file_name = file_prefix + ".parquet"; - ARROW_ASSIGN_OR_RAISE(auto file_path, tempdir->path().Join(file_name)); data_file_path = file_path.ToString(); - - std::string toReplace("/T//"); - size_t pos = data_file_path.find(toReplace); - if (pos >= 0) { - data_file_path.replace(pos, toReplace.length(), "/T/"); - } - ARROW_EXPECT_OK(WriteParquetData(data_file_path, filesystem, input_table)); return Status::OK(); } From 5051070dc7be84a0916d6432c7465d15f19a413a Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 7 Sep 2022 12:23:12 +0530 Subject: [PATCH 19/22] fix(reviews): imd commit --- cpp/src/arrow/engine/substrait/serde_test.cc | 563 ++++++++----------- 1 file changed, 241 insertions(+), 322 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 66806483f6d..bafdd7b0f21 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -217,47 +217,33 @@ struct TempDataGenerator { std::string data_file_path; }; -struct EmitValidate { - EmitValidate(const std::shared_ptr output_schema, - const std::shared_ptr
expected_table, - compute::ExecContext& exec_context, std::shared_ptr& buf, - const std::vector& include_columns = {}) - : output_schema(output_schema), - expected_table(expected_table), - exec_context(exec_context), - buf(buf), - include_columns(include_columns) {} - void operator()() { - for (auto sp_ext_id_reg : - {std::shared_ptr(), MakeExtensionIdRegistry()}) { - ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); - ExtensionSet ext_set(ext_id_reg); - ASSERT_OK_AND_ASSIGN(auto sink_decls, - DeserializePlans( - *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); - auto other_declrs = sink_decls[0].inputs[0].get(); - arrow::AsyncGenerator> sink_gen; - auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); - auto declarations = - compute::Declaration::Sequence({*other_declrs, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN(auto output_table, - GetTableFromPlan(acero_plan, declarations, sink_gen, - exec_context, output_schema)); - if (!include_columns.empty()) { - ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); - } - ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); - EXPECT_TRUE(expected_table->Equals(*output_table)); - } +void CheckRoundTripResult(const std::shared_ptr output_schema, + const std::shared_ptr
expected_table, + compute::ExecContext& exec_context, + std::shared_ptr& buf, + const std::vector& include_columns = {}, + const ConversionOptions& conversion_options = {}) { + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans( + *buf, [] { return kNullConsumer; }, + ext_id_reg, &ext_set, conversion_options)); + auto other_declrs = sink_decls[0].inputs[0].get(); + arrow::AsyncGenerator> sink_gen; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); + auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN( + auto output_table, + GetTableFromPlan(acero_plan, declarations, sink_gen, exec_context, output_schema)); + if (!include_columns.empty()) { + ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); } - std::shared_ptr output_schema; - std::shared_ptr
expected_table; - compute::ExecContext exec_context; - std::shared_ptr buf; - const std::vector& include_columns; -}; + ASSERT_OK_AND_ASSIGN(output_table, output_table->CombineChunks()); + EXPECT_TRUE(expected_table->Equals(*output_table)); +} TEST(Substrait, SupportedTypes) { auto ExpectEq = [](util::string_view json, std::shared_ptr expected_type) { @@ -2205,14 +2191,6 @@ TEST(Substrait, ProjectRel) { [3, 4, 20] ])"}); - std::string file_prefix = "serde_project_test"; - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_project_tempdir")); - - TempDataGenerator datagen(input_table, file_prefix, tempdir); - ASSERT_OK(datagen()); - std::string substrait_file_uri = "file://" + datagen.data_file_path; - std::string substrait_json = R"({ "relations": [{ "rel": { @@ -2263,15 +2241,8 @@ TEST(Substrait, ProjectRel) { }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable": { + "names": [] } } } @@ -2301,7 +2272,18 @@ TEST(Substrait, ProjectRel) { [1, 1, 10, 2], [3, 4, 20, 7] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); + + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } TEST(Substrait, ProjectRelOnFunctionWithEmit) { @@ -2318,14 +2300,6 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { [3, 4, 20] ])"}); - std::string file_prefix = "serde_project_emit_test"; - ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make( - "substrait_project_emit_tempdir")); - - TempDataGenerator datagen(input_table, file_prefix, tempdir); - ASSERT_OK(datagen()); - std::string substrait_file_uri = "file://" + datagen.data_file_path; - std::string substrait_json = R"({ "relations": [{ "rel": { @@ -2380,15 +2354,8 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable": { + "names": [] } } } @@ -2418,75 +2385,90 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { [1, 10, 2], [3, 20, 7] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); -} - -TEST(Substrait, ReadRelWithEmit) { -#ifdef _WIN32 - GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; -#endif - compute::ExecContext exec_context; - auto dummy_schema = - schema({field("A", int32()), field("B", int32()), field("C", int32())}); - - // creating a dummy dataset using a dummy table - auto input_table = TableFromJSON(dummy_schema, {R"([ - [1, 1, 10], - [3, 4, 20] - ])"}); - - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); - std::string file_prefix = "serde_read_emit_test"; - - TempDataGenerator datagen(input_table, file_prefix, tempdir); - ASSERT_OK(datagen()); - std::string substrait_file_uri = "file://" + datagen.data_file_path; + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; - std::string substrait_json = R"({ - "relations": [{ - "rel": { - "read": { - "common": { - "emit": { - "outputMapping": [1, 2] - } - }, - "base_schema": { - "names": ["A", "B", "C"], - "struct": { - "types": [{ - "i32": {} - }, { - "i32": {} - }, { - "i32": {} - }] - } - }, - "local_files": { - "items": [ - { - "uri_file": ")" + substrait_file_uri + - R"(", - "parquet": {} - } - ] - } - } - } - }], - })"; + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); - ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); - auto output_schema = schema({field("B", int32()), field("C", int32())}); - auto expected_table = TableFromJSON(output_schema, {R"([ - [1, 10], - [4, 20] - ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } +// TEST(Substrait, ReadRelWithEmit) { +// #ifdef _WIN32 +// GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +// #endif +// compute::ExecContext exec_context; +// auto dummy_schema = +// schema({field("A", int32()), field("B", int32()), field("C", int32())}); + +// // creating a dummy dataset using a dummy table +// auto input_table = TableFromJSON(dummy_schema, {R"([ +// [1, 1, 10], +// [3, 4, 20] +// ])"}); + +// ASSERT_OK_AND_ASSIGN(auto tempdir, +// arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); +// std::string file_prefix = "serde_read_emit_test"; + +// TempDataGenerator datagen(input_table, file_prefix, tempdir); +// ASSERT_OK(datagen()); +// std::string substrait_file_uri = "file://" + datagen.data_file_path; + +// std::string substrait_json = R"({ +// "relations": [{ +// "rel": { +// "read": { +// "common": { +// "emit": { +// "outputMapping": [1, 2] +// } +// }, +// "base_schema": { +// "names": ["A", "B", "C"], +// "struct": { +// "types": [{ +// "i32": {} +// }, { +// "i32": {} +// }, { +// "i32": {} +// }] +// } +// }, +// namedTable: { +// "names" : [] +// } +// } +// } +// }], +// })"; + +// ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); +// auto output_schema = schema({field("B", int32()), field("C", int32())}); +// auto expected_table = TableFromJSON(output_schema, {R"([ +// [1, 10], +// [4, 20] +// ])"}); +// NamedTableProvider table_provider = [input_table](const std::vector&) { +// std::shared_ptr options = +// std::make_shared(input_table); +// return compute::Declaration("table_source", {}, options, "mock_source"); +// }; + +// ConversionOptions conversion_options; +// conversion_options.named_table_provider = std::move(table_provider); + +// CheckRoundTripResult(std::move(output_schema), std::move(expected_table), +// exec_context, +// buf, {}, conversion_options); +// } + TEST(Substrait, FilterRelWithEmit) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; @@ -2572,15 +2554,8 @@ TEST(Substrait, FilterRelWithEmit) { }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable": { + "names" : [] } } } @@ -2611,7 +2586,17 @@ TEST(Substrait, FilterRelWithEmit) { [6, 2], [7, 1] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } TEST(Substrait, JoinRelEndToEnd) { @@ -2619,38 +2604,23 @@ TEST(Substrait, JoinRelEndToEnd) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif compute::ExecContext exec_context; - auto left_schema = schema({field("A", int32()), field("B", int32()), - field("C", int32()), field("D", int32())}); + auto left_schema = schema({field("A", int32()), field("B", int32())}); - auto right_schema = schema({field("X", int32()), field("Y", int32()), - field("Z", int32()), field("W", int32())}); + auto right_schema = schema({field("X", int32()), field("Y", int32())}); // creating a dummy dataset using a dummy table auto left_table = TableFromJSON(left_schema, {R"([ - [10, 1, 80, 70], - [20, 2, 70, 60], - [30, 3, 30, 50] + [10, 1], + [20, 2], + [30, 3] ])"}); auto right_table = TableFromJSON(right_schema, {R"([ - [10, 1, 81, 71], - [80, 2, 71, 61], - [31, 3, 31, 51] + [10, 11], + [80, 21], + [31, 31] ])"}); - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_join_tempdir")); - std::string left_file_prefix = "serde_join_left_emit_test"; - std::string right_file_prefix = "serde_join_right_emit_test"; - - TempDataGenerator datagen_left(left_table, left_file_prefix, tempdir); - ASSERT_OK(datagen_left()); - std::string substrait_left_file_uri = "file://" + datagen_left.data_file_path; - - TempDataGenerator datagen_right(right_table, right_file_prefix, tempdir); - ASSERT_OK(datagen_right()); - std::string substrait_right_file_uri = "file://" + datagen_right.data_file_path; - std::string substrait_json = R"({ "relations": [{ "rel": { @@ -2658,56 +2628,34 @@ TEST(Substrait, JoinRelEndToEnd) { "left": { "read": { "base_schema": { - "names": ["A", "B", "C", "D"], + "names": ["A", "B"], "struct": { "types": [{ "i32": {} }, { "i32": {} - }, { - "i32": {} - }, { - "i32": {} }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_left_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable": { + "names" : ["left"] } } }, "right": { "read": { "base_schema": { - "names": ["X", "Y", "Z", "W"], + "names": ["X", "Y"], "struct": { "types": [{ "i32": {} }, { "i32": {} - }, { - "i32": {} - }, { - "i32": {} }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_right_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable": { + "names" : ["right"] } } }, @@ -2765,42 +2713,40 @@ TEST(Substrait, JoinRelEndToEnd) { })"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); - auto output_schema = schema({ - field("A", int32()), - field("B", int32()), - field("C", int32()), - field("D", int32()), - field("__fragment_index_l", int32()), - field("__batch_index_l", int32()), - field("__last_in_fragment_l", boolean()), - field("__filename_l", utf8()), - field("X", int32()), - field("Y", int32()), - field("Z", int32()), - field("W", int32()), - field("__fragment_index_r", int32()), - field("__batch_index_r", int32()), - field("__last_in_fragment_r", boolean()), - field("__filename_r", utf8()), - }); // include these columns for comparison - std::vector include_columns{0, 1, 2, 3, 8, 9, 10, 11}; - auto compared_output_schema = schema({ + auto output_schema = schema({ field("A", int32()), field("B", int32()), - field("C", int32()), - field("D", int32()), field("X", int32()), field("Y", int32()), - field("Z", int32()), - field("W", int32()), }); - auto expected_table = TableFromJSON(std::move(compared_output_schema), {R"([ - [10, 1, 80, 70, 10, 1, 81, 71] + + auto expected_table = TableFromJSON(std::move(output_schema), {R"([ + [10, 1, 10, 11] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf, - std::move(include_columns))(); + + NamedTableProvider table_provider = + [left_table, right_table](const std::vector& names) { + std::shared_ptr
output_table; + for (const auto& name : names) { + if (name == "left") { + output_table = left_table; + } + if (name == "right") { + output_table = right_table; + } + } + std::shared_ptr options = + std::make_shared(std::move(output_table)); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } TEST(Substrait, JoinRelWithEmit) { @@ -2808,100 +2754,63 @@ TEST(Substrait, JoinRelWithEmit) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif compute::ExecContext exec_context; - auto left_schema = schema({field("A", int32()), field("B", int32()), - field("C", int32()), field("D", int32())}); + auto left_schema = schema({field("A", int32()), field("B", int32())}); - auto right_schema = schema({field("X", int32()), field("Y", int32()), - field("Z", int32()), field("W", int32())}); + auto right_schema = schema({field("X", int32()), field("Y", int32())}); // creating a dummy dataset using a dummy table auto left_table = TableFromJSON(left_schema, {R"([ - [10, 1, 80, 70], - [20, 2, 70, 60], - [30, 3, 30, 50] + [10, 1], + [20, 2], + [30, 3] ])"}); auto right_table = TableFromJSON(right_schema, {R"([ - [10, 1, 81, 71], - [80, 2, 71, 61], - [31, 3, 31, 51] + [10, 11], + [80, 21], + [31, 31] ])"}); - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_join_tempdir")); - std::string left_file_prefix = "serde_join_left_emit_test"; - std::string right_file_prefix = "serde_join_right_emit_test"; - - TempDataGenerator datagen_left(left_table, left_file_prefix, tempdir); - ASSERT_OK(datagen_left()); - std::string substrait_left_file_uri = "file://" + datagen_left.data_file_path; - - TempDataGenerator datagen_right(right_table, right_file_prefix, tempdir); - ASSERT_OK(datagen_right()); - std::string substrait_right_file_uri = "file://" + datagen_right.data_file_path; - std::string substrait_json = R"({ "relations": [{ "rel": { "join": { "common": { "emit": { - "outputMapping": [0, 1, 2, 3, 8, 9, 10, 11] + "outputMapping": [0, 1, 3] } }, "left": { "read": { "base_schema": { - "names": ["A", "B", "C", "D"], + "names": ["A", "B"], "struct": { "types": [{ "i32": {} }, { "i32": {} - }, { - "i32": {} - }, { - "i32": {} }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_left_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable" : { + "names" : ["left"] } } }, "right": { "read": { "base_schema": { - "names": ["X", "Y", "Z", "W"], + "names": ["X", "Y"], "struct": { "types": [{ "i32": {} }, { "i32": {} - }, { - "i32": {} - }, { - "i32": {} }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_right_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable" : { + "names" : ["right"] } } }, @@ -2962,18 +2871,34 @@ TEST(Substrait, JoinRelWithEmit) { auto output_schema = schema({ field("A", int32()), field("B", int32()), - field("C", int32()), - field("D", int32()), - field("X", int32()), field("Y", int32()), - field("Z", int32()), - field("W", int32()), }); auto expected_table = TableFromJSON(std::move(output_schema), {R"([ - [10, 1, 80, 70, 10, 1, 81, 71] + [10, 1, 11] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); + + NamedTableProvider table_provider = + [left_table, right_table](const std::vector& names) { + std::shared_ptr
output_table; + for (const auto& name : names) { + if (name == "left") { + output_table = left_table; + } + if (name == "right") { + output_table = right_table; + } + } + std::shared_ptr options = + std::make_shared(std::move(output_table)); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } TEST(Substrait, AggregateRel) { @@ -2995,13 +2920,6 @@ TEST(Substrait, AggregateRel) { [30, 7, 30] ])"}); - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_agg_tempdir")); - std::string file_prefix = "serde_agg_emit_test"; - - TempDataGenerator datagen(input_table, file_prefix, tempdir); - ASSERT_OK(datagen()); - std::string substrait_file_uri = "file://" + datagen.data_file_path; std::string substrait_json = R"({ "relations": [{ "rel": { @@ -3020,15 +2938,8 @@ TEST(Substrait, AggregateRel) { }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable" : { + "names": [] } } }, @@ -3089,7 +3000,18 @@ TEST(Substrait, AggregateRel) { [60, 30], [60, 40] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); + + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } TEST(Substrait, AggregateRelEmit) { @@ -3111,13 +3033,6 @@ TEST(Substrait, AggregateRelEmit) { [30, 7, 30] ])"}); - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_agg_tempdir")); - std::string file_prefix = "serde_agg_emit_test"; - - TempDataGenerator datagen(input_table, file_prefix, tempdir); - ASSERT_OK(datagen()); - std::string substrait_file_uri = "file://" + datagen.data_file_path; // TODO: fixme https://issues.apache.org/jira/browse/ARROW-17484 std::string substrait_json = R"({ "relations": [{ @@ -3142,15 +3057,8 @@ TEST(Substrait, AggregateRelEmit) { }] } }, - "local_files": { - "items": [ - { - "uri_file": ")" + - substrait_file_uri + - R"(", - "parquet": {} - } - ] + "namedTable" : { + "names" : [] } } }, @@ -3211,7 +3119,18 @@ TEST(Substrait, AggregateRelEmit) { [60], [60] ])"}); - EmitValidate(std::move(output_schema), std::move(expected_table), exec_context, buf)(); + + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); } } // namespace engine From 5bb10512ca8a37aa2f5b0c832065f808929cff27 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 7 Sep 2022 12:56:32 +0530 Subject: [PATCH 20/22] fix(read): namedTable emit config added --- .../engine/substrait/relation_internal.cc | 6 +- cpp/src/arrow/engine/substrait/serde_test.cc | 132 ++++++++---------- 2 files changed, 66 insertions(+), 72 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index e01ee5050bd..c58339d5cff 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -143,9 +143,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& const substrait::ReadRel::NamedTable& named_table = read.named_table(); std::vector table_names(named_table.names().begin(), named_table.names().end()); - ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, + ARROW_ASSIGN_OR_RAISE(compute::Declaration no_emit_declaration, named_table_provider(table_names)); - return DeclarationInfo{std::move(source_decl), base_schema}; + return ProcessEmit(std::move(read), + DeclarationInfo{std::move(no_emit_declaration), base_schema}, + std::move(base_schema)); } if (!read.has_local_files()) { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index bafdd7b0f21..ed450f0b7b2 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -2398,76 +2398,68 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { buf, {}, conversion_options); } -// TEST(Substrait, ReadRelWithEmit) { -// #ifdef _WIN32 -// GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; -// #endif -// compute::ExecContext exec_context; -// auto dummy_schema = -// schema({field("A", int32()), field("B", int32()), field("C", int32())}); - -// // creating a dummy dataset using a dummy table -// auto input_table = TableFromJSON(dummy_schema, {R"([ -// [1, 1, 10], -// [3, 4, 20] -// ])"}); - -// ASSERT_OK_AND_ASSIGN(auto tempdir, -// arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); -// std::string file_prefix = "serde_read_emit_test"; - -// TempDataGenerator datagen(input_table, file_prefix, tempdir); -// ASSERT_OK(datagen()); -// std::string substrait_file_uri = "file://" + datagen.data_file_path; - -// std::string substrait_json = R"({ -// "relations": [{ -// "rel": { -// "read": { -// "common": { -// "emit": { -// "outputMapping": [1, 2] -// } -// }, -// "base_schema": { -// "names": ["A", "B", "C"], -// "struct": { -// "types": [{ -// "i32": {} -// }, { -// "i32": {} -// }, { -// "i32": {} -// }] -// } -// }, -// namedTable: { -// "names" : [] -// } -// } -// } -// }], -// })"; - -// ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); -// auto output_schema = schema({field("B", int32()), field("C", int32())}); -// auto expected_table = TableFromJSON(output_schema, {R"([ -// [1, 10], -// [4, 20] -// ])"}); -// NamedTableProvider table_provider = [input_table](const std::vector&) { -// std::shared_ptr options = -// std::make_shared(input_table); -// return compute::Declaration("table_source", {}, options, "mock_source"); -// }; - -// ConversionOptions conversion_options; -// conversion_options.named_table_provider = std::move(table_provider); - -// CheckRoundTripResult(std::move(output_schema), std::move(expected_table), -// exec_context, -// buf, {}, conversion_options); -// } +TEST(Substrait, ReadRelWithEmit) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + auto dummy_schema = + schema({field("A", int32()), field("B", int32()), field("C", int32())}); + + // creating a dummy dataset using a dummy table + auto input_table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])"}); + + std::string substrait_json = R"({ + "relations": [{ + "rel": { + "read": { + "common": { + "emit": { + "outputMapping": [1, 2] + } + }, + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "namedTable": { + "names" : [] + } + } + } + }], + })"; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + auto output_schema = schema({field("B", int32()), field("C", int32())}); + auto expected_table = TableFromJSON(output_schema, {R"([ + [1, 10], + [4, 20] + ])"}); + + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + CheckRoundTripResult(std::move(output_schema), std::move(expected_table), exec_context, + buf, {}, conversion_options); +} TEST(Substrait, FilterRelWithEmit) { #ifdef _WIN32 From 19e49ed3ddcb3888c502fff1900adb41deff073f Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Fri, 9 Sep 2022 19:26:59 +0530 Subject: [PATCH 21/22] fix(rebase) --- cpp/src/arrow/engine/substrait/serde_test.cc | 53 ++++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index ed450f0b7b2..69300152a61 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -2188,7 +2188,11 @@ TEST(Substrait, ProjectRel) { // creating a dummy dataset using a dummy table auto input_table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], - [3, 4, 20] + [3, 5, 20], + [4, 1, 30], + [2, 1, 40], + [5, 5, 50], + [2, 2, 60] ])"}); std::string substrait_json = R"({ @@ -2210,8 +2214,7 @@ TEST(Substrait, ProjectRel) { } } } - }, - { + }, { "value": { "selection": { "directReference": { @@ -2223,7 +2226,10 @@ TEST(Substrait, ProjectRel) { } } } - }] + }], + "output_type": { + "bool": {} + } } }, ], @@ -2252,7 +2258,7 @@ TEST(Substrait, ProjectRel) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], @@ -2260,17 +2266,21 @@ TEST(Substrait, ProjectRel) { {"extension_function": { "extension_uri_reference": 0, "function_anchor": 0, - "name": "add" + "name": "equal" }} ] })"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); auto output_schema = schema({field("A", int32()), field("B", int32()), - field("C", int32()), field("ADD", int32())}); + field("C", int32()), field("equal", boolean())}); auto expected_table = TableFromJSON(output_schema, {R"([ - [1, 1, 10, 2], - [3, 4, 20, 7] + [1, 1, 10, true], + [3, 5, 20, false], + [4, 1, 30, false], + [2, 1, 40, false], + [5, 5, 50, true], + [2, 2, 60, true] ])"}); NamedTableProvider table_provider = [input_table](const std::vector&) { @@ -2297,7 +2307,11 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { // creating a dummy dataset using a dummy table auto input_table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], - [3, 4, 20] + [3, 5, 20], + [4, 1, 30], + [2, 1, 40], + [5, 5, 50], + [2, 2, 60] ])"}); std::string substrait_json = R"({ @@ -2336,7 +2350,10 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { } } } - }] + }], + "output_type": { + "bool": {} + } } }, ], @@ -2365,7 +2382,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], @@ -2373,17 +2390,21 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { {"extension_function": { "extension_uri_reference": 0, "function_anchor": 0, - "name": "add" + "name": "equal" }} ] })"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); auto output_schema = - schema({field("A", int32()), field("C", int32()), field("add", int32())}); + schema({field("A", int32()), field("C", int32()), field("equal", boolean())}); auto expected_table = TableFromJSON(output_schema, {R"([ - [1, 10, 2], - [3, 20, 7] + [1, 10, true], + [3, 20, false], + [4, 30, false], + [2, 40, false], + [5, 50, true], + [2, 60, true] ])"}); NamedTableProvider table_provider = [input_table](const std::vector&) { std::shared_ptr options = From 2416e957a51a48b13a9836a23300647c8979561f Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Fri, 9 Sep 2022 20:51:13 +0530 Subject: [PATCH 22/22] fix(reviews): address reviews --- .../engine/substrait/relation_internal.cc | 45 +++---- cpp/src/arrow/engine/substrait/serde_test.cc | 113 ++++-------------- 2 files changed, 40 insertions(+), 118 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c58339d5cff..4213895b616 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -41,14 +41,6 @@ using internal::make_unique; namespace engine { -template -bool HasEmit(const RelMessage& rel) { - if (rel.has_common()) { - return rel.common().has_emit(); - } - return false; -} - template Result> GetEmitInfo( const RelMessage& rel, const std::shared_ptr& schema) { @@ -143,10 +135,10 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& const substrait::ReadRel::NamedTable& named_table = read.named_table(); std::vector table_names(named_table.names().begin(), named_table.names().end()); - ARROW_ASSIGN_OR_RAISE(compute::Declaration no_emit_declaration, + ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, named_table_provider(table_names)); return ProcessEmit(std::move(read), - DeclarationInfo{std::move(no_emit_declaration), base_schema}, + DeclarationInfo{std::move(source_decl), base_schema}, std::move(base_schema)); } @@ -261,11 +253,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); - DeclarationInfo no_emit_declaration = { + DeclarationInfo scan_declaration = { compute::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}}, base_schema}; - return ProcessEmit(std::move(read), std::move(no_emit_declaration), + return ProcessEmit(std::move(read), std::move(scan_declaration), std::move(base_schema)); } @@ -284,14 +276,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); - DeclarationInfo no_emit_declaration{ + DeclarationInfo filter_declaration{ compute::Declaration::Sequence({ std::move(input.declaration), {"filter", compute::FilterNodeOptions{std::move(condition)}}, }), input.output_schema}; - return ProcessEmit(std::move(filter), std::move(no_emit_declaration), + return ProcessEmit(std::move(filter), std::move(filter_declaration), input.output_schema); } @@ -340,14 +332,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& expressions.emplace_back(des_expr); } - DeclarationInfo no_emit_declaration{ + DeclarationInfo project_declaration{ compute::Declaration::Sequence({ std::move(input.declaration), {"project", compute::ProjectNodeOptions{std::move(expressions)}}, }), project_schema}; - return ProcessEmit(std::move(project), std::move(no_emit_declaration), + return ProcessEmit(std::move(project), std::move(project_declaration), std::move(project_schema)); } @@ -428,14 +420,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } // Create output schema from left, right relations and join keys - std::shared_ptr join_schema = left.output_schema; - std::shared_ptr right_schema = right.output_schema; - - for (const auto& field : right_schema->fields()) { - ARROW_ASSIGN_OR_RAISE( - join_schema, join_schema->AddField( - static_cast(join_schema->fields().size()) - 1, field)); - } + FieldVector combined_fields = left.output_schema->fields(); + const FieldVector& right_fields = right.output_schema->fields(); + combined_fields.insert(combined_fields.end(), right_fields.begin(), + right_fields.end()); + std::shared_ptr join_schema = schema(std::move(combined_fields)); compute::HashJoinNodeOptions join_options{{std::move(*left_keys)}, {std::move(*right_keys)}}; @@ -445,9 +434,9 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& join_dec.inputs.emplace_back(std::move(left.declaration)); join_dec.inputs.emplace_back(std::move(right.declaration)); - DeclarationInfo no_emit_declaration{std::move(join_dec), join_schema}; + DeclarationInfo join_declaration{std::move(join_dec), join_schema}; - return ProcessEmit(std::move(join), std::move(no_emit_declaration), + return ProcessEmit(std::move(join), std::move(join_declaration), std::move(join_schema)); } case substrait::Rel::RelTypeCase::kAggregate: { @@ -536,13 +525,13 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::shared_ptr aggregate_schema = schema(std::move(output_fields)); - DeclarationInfo no_emit_declaration{ + DeclarationInfo aggregate_declaration{ compute::Declaration::Sequence( {std::move(input.declaration), {"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}), aggregate_schema}; - return ProcessEmit(std::move(aggregate), std::move(no_emit_declaration), + return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration), std::move(aggregate_schema)); } diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 69300152a61..251c2bfe352 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -56,56 +56,35 @@ using internal::checked_cast; using internal::hash_combine; namespace engine { -Status WriteIpcData(const std::string& path, - const std::shared_ptr file_system, - const std::shared_ptr
input) { +void WriteIpcData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr
input) { EXPECT_OK_AND_ASSIGN(auto mmap, file_system->OpenOutputStream(path)); - ARROW_ASSIGN_OR_RAISE( + ASSERT_OK_AND_ASSIGN( auto file_writer, MakeFileWriter(mmap, input->schema(), ipc::IpcWriteOptions::Defaults())); TableBatchReader reader(input); std::shared_ptr batch; while (true) { - RETURN_NOT_OK(reader.ReadNext(&batch)); + ASSERT_OK(reader.ReadNext(&batch)); if (batch == nullptr) { break; } - RETURN_NOT_OK(file_writer->WriteRecordBatch(*batch)); + ASSERT_OK(file_writer->WriteRecordBatch(*batch)); } - RETURN_NOT_OK(file_writer->Close()); - return Status::OK(); + ASSERT_OK(file_writer->Close()); } Result> GetTableFromPlan( - compute::Declaration& declarations, - arrow::AsyncGenerator>& sink_gen, - compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + compute::Declaration& other_declrs, compute::ExecContext& exec_context, + const std::shared_ptr& output_schema) { ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); - ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); - RETURN_NOT_OK(decl->Validate()); - - std::shared_ptr sink_reader = compute::MakeGeneratorReader( - output_schema, std::move(sink_gen), exec_context.memory_pool()); - - RETURN_NOT_OK(plan->Validate()); - RETURN_NOT_OK(plan->StartProducing()); - return arrow::Table::FromRecordBatchReader(sink_reader.get()); -} - -Status WriteParquetData(const std::string& path, - const std::shared_ptr file_system, - const std::shared_ptr
input) { - EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path)); - PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(), - buffer_writer, /*chunk_size*/ 1)); - return buffer_writer->Close(); -} + arrow::AsyncGenerator> sink_gen; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); + auto declarations = compute::Declaration::Sequence({other_declrs, sink_declaration}); -Result> GetTableFromPlan( - std::shared_ptr& plan, compute::Declaration& declarations, - arrow::AsyncGenerator>& sink_gen, - compute::ExecContext& exec_context, const std::shared_ptr& output_schema) { ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); RETURN_NOT_OK(decl->Validate()); @@ -194,29 +173,6 @@ inline compute::Expression UseBoringRefs(const compute::Expression& expr) { return compute::Expression{std::move(modified_call)}; } -// TODO: complete this interface -struct TempDataGenerator { - TempDataGenerator(const std::shared_ptr
input_table, - const std::string& file_prefix, - std::unique_ptr& tempdir) - : input_table(input_table), file_prefix(file_prefix), tempdir(tempdir) {} - - Status operator()() { - auto format = std::make_shared(); - auto filesystem = std::make_shared(); - const std::string file_name = file_prefix + ".parquet"; - ARROW_ASSIGN_OR_RAISE(auto file_path, tempdir->path().Join(file_name)); - data_file_path = file_path.ToString(); - ARROW_EXPECT_OK(WriteParquetData(data_file_path, filesystem, input_table)); - return Status::OK(); - } - - std::shared_ptr
input_table; - std::string file_prefix; - std::unique_ptr& tempdir; - std::string data_file_path; -}; - void CheckRoundTripResult(const std::shared_ptr output_schema, const std::shared_ptr
expected_table, compute::ExecContext& exec_context, @@ -230,14 +186,9 @@ void CheckRoundTripResult(const std::shared_ptr output_schema, *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set, conversion_options)); auto other_declrs = sink_decls[0].inputs[0].get(); - arrow::AsyncGenerator> sink_gen; - auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); - auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN( - auto output_table, - GetTableFromPlan(acero_plan, declarations, sink_gen, exec_context, output_schema)); + + ASSERT_OK_AND_ASSIGN(auto output_table, + GetTableFromPlan(*other_declrs, exec_context, output_schema)); if (!include_columns.empty()) { ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); } @@ -1977,7 +1928,7 @@ TEST(Substrait, BasicPlanRoundTripping) { ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + WriteIpcData(file_path_str, filesystem, table); std::vector files; const std::vector f_paths = {file_path_str}; @@ -2089,7 +2040,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + WriteIpcData(file_path_str, filesystem, table); std::vector files; const std::vector f_paths = {file_path_str}; @@ -2111,16 +2062,13 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { auto comp_right_value = compute::field_ref(filter_col_right); auto filter = compute::equal(comp_left_value, comp_right_value); - arrow::AsyncGenerator> sink_gen; - auto declarations = compute::Declaration::Sequence( {compute::Declaration( {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), - compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), - compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"})}); - ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(declarations, sink_gen, - exec_context, dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto expected_table, + GetTableFromPlan(declarations, exec_context, dummy_schema)); std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); @@ -2165,15 +2113,8 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { checked_cast(roundtrip_frg_vec[idx++].get()); EXPECT_TRUE(l_frag->Equals(*r_frag)); } - arrow::AsyncGenerator> rnd_trp_sink_gen; - auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; - auto rnd_trp_sink_declaration = - compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); - auto rnd_trp_declarations = - compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, - GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, - exec_context, dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, GetTableFromPlan(*roundtripped_filter, + exec_context, dummy_schema)); EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } @@ -2501,14 +2442,6 @@ TEST(Substrait, FilterRelWithEmit) { [30, 7, 30, 1] ])"}); - ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_read_tempdir")); - std::string file_prefix = "serde_read_emit_test"; - - TempDataGenerator datagen(input_table, file_prefix, tempdir); - ASSERT_OK(datagen()); - std::string substrait_file_uri = "file://" + datagen.data_file_path; - std::string substrait_json = R"({ "relations": [{ "rel": {