From 5b691f8522f4a35e744e359c03491391a3aa2357 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 27 Jul 2022 10:15:45 +0530 Subject: [PATCH 01/30] rebase: short history 16855 --- cpp/src/arrow/compute/exec/exec_plan.h | 10 + cpp/src/arrow/compute/exec/plan_test.cc | 39 +++ cpp/src/arrow/compute/exec/source_node.cc | 2 +- cpp/src/arrow/dataset/scanner_test.cc | 18 + .../arrow/engine/substrait/plan_internal.cc | 48 +++ .../arrow/engine/substrait/plan_internal.h | 5 + cpp/src/arrow/engine/substrait/registry.cc | 86 +++++ cpp/src/arrow/engine/substrait/registry.h | 69 ++++ .../engine/substrait/relation_internal.cc | 82 +++++ .../engine/substrait/relation_internal.h | 7 + cpp/src/arrow/engine/substrait/serde.cc | 19 +- cpp/src/arrow/engine/substrait/serde.h | 18 + cpp/src/arrow/engine/substrait/serde_test.cc | 316 +++++++++++++++++- 13 files changed, 713 insertions(+), 6 deletions(-) create mode 100644 cpp/src/arrow/engine/substrait/registry.cc create mode 100644 cpp/src/arrow/engine/substrait/registry.h diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 263f3634a5a..4e1a0398867 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -199,6 +199,14 @@ class ARROW_EXPORT ExecNode { /// This node's exec plan ExecPlan* plan() { return plan_; } + /// Set this node's options + /// This is an optional method included to support Acero to Substrait + /// serialization. + void SetOptions(ExecNodeOptions* options) { options_ = options; } + + /// This node's options + ExecNodeOptions* options() { return options_; } + /// \brief An optional label, for display and debugging /// /// There is no guarantee that this value is non-empty or unique. @@ -367,6 +375,8 @@ class ARROW_EXPORT ExecNode { Future<> finished_ = Future<>::Make(); util::tracing::Span span_; + + ExecNodeOptions* options_; }; /// \brief MapNode is an ExecNode type class which process a task like filter/project diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index e06c41c7489..e0262142914 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -279,6 +279,27 @@ TEST(ExecPlanExecution, TableSourceSink) { } } +TEST(ExecPlanExecution, TableSourceNodeOptions) { + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + AsyncGenerator> sink_gen; + + auto exp_batches = MakeBasicBatches(); + ASSERT_OK_AND_ASSIGN(auto table, + TableFromExecBatches(exp_batches.schema, exp_batches.batches)); + auto table_source_options = TableSourceNodeOptions{table, 3}; + + ASSERT_OK_AND_ASSIGN( + ExecNode * table_source, + MakeExecNode("table_source", plan.get(), {}, table_source_options)); + + table_source->SetOptions(&table_source_options); + const auto& res_table_options = static_cast(*table_source->options()); + + EXPECT_EQ(table_source_options.table, res_table_options.table); + EXPECT_EQ(table_source_options.max_batch_size, res_table_options.max_batch_size); +} + TEST(ExecPlanExecution, TableSourceSinkError) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); AsyncGenerator> sink_gen; @@ -1481,5 +1502,23 @@ TEST(ExecPlan, SourceEnforcesBatchLimit) { } } +TEST(ExecPlan, ExecNodeOption) { + auto input = MakeGroupableBatches(); + + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), arrow::internal::GetCpuThreadPool()); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + + AsyncGenerator> sink_gen; + + SourceNodeOptions options{input.schema, input.gen(/*parallel*/ true, /*slow=*/false)}; + + ASSERT_OK_AND_ASSIGN(auto* source, MakeExecNode("source", plan.get(), {}, options)); + source->SetOptions(&options); + const auto& opts = static_cast(*source->options()); + ASSERT_EQ(opts.output_schema, options.output_schema); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index a640cf737ef..783b7d33e8b 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -48,7 +48,7 @@ namespace { struct SourceNode : ExecNode { SourceNode(ExecPlan* plan, std::shared_ptr output_schema, AsyncGenerator> generator) - : ExecNode(plan, {}, {}, std::move(output_schema), + : ExecNode(plan, {}, {}, output_schema, /*num_outputs=*/1), generator_(std::move(generator)) {} diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 804e82b57db..982306725c7 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1991,5 +1991,23 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) { AssertTablesEqual(*expected, *sorted.table(), /*same_chunk_layout=*/false); } +TEST(ScanNode, NodeOptions) { + TestPlan plan; + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->projection = Materialize({}); // set an empty projection + ScanNodeOptions scan_node_options{basic.dataset, options}; + ASSERT_OK_AND_ASSIGN(auto* scan, + compute::MakeExecNode("scan", plan.get(), {}, scan_node_options)); + scan->SetOptions(&scan_node_options); + const auto& res_scan_options = static_cast(*scan->options()); + + ASSERT_EQ(scan_node_options.dataset->schema(), res_scan_options.dataset->schema()); + ASSERT_EQ(scan_node_options.scan_options->projection, + res_scan_options.scan_options->projection); +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index b0fdb9bdc2f..952a3f47f87 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -133,5 +133,53 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, registry); } +Status TraversePlan(compute::ExecNode* node, + std::vector* node_vector) { + if (node->num_inputs() > 0) { + const auto& sources = node->inputs(); + for (const auto& source : sources) { + ARROW_RETURN_NOT_OK(TraversePlan(source, node_vector)); + } + } + node_vector->push_back(node); + return Status::OK(); +} + +// arrow::Result MakeDeclaration(compute::ExecNode* node) { +// compute::Declaration decl; +// const auto* kind_name = node->kind_name(); +// if(kind_name == std::string("scan")) { +// std::vector inputs = node->inputs(); + +// } else { +// return Status::NotImplemented(kind_name, " relation not implemented."); +// } +// return decl; +// } + +Result> ToProto(const compute::ExecPlan& plan, + ExtensionSet* ext_set) { + auto plan_rel = internal::make_unique(); + + std::cout << "Plan Show" << std::endl; + + std::cout << plan.ToString() << std::endl; + std::cout << "----------" << std::endl; + const auto& sinks = plan.sinks(); + std::vector node_vec; + ARROW_RETURN_NOT_OK(TraversePlan(sinks[0], &node_vec)); + + std::cout << "Printing Node Vector" << std::endl; + for (const auto& node : node_vec) { + const auto* kind_name = node->kind_name(); + std::cout << kind_name << std::endl; + //const auto& output_schema = node->output_schema(); + // if (output_schema) { + // } + } + + return plan_rel; +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index dce23cdceba..2ae3cdd3ec2 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -19,6 +19,8 @@ #pragma once +#include "arrow/compute/exec/exec_plan.h" + #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -51,5 +53,8 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); +ARROW_ENGINE_EXPORT Result> ToProto( + const compute::ExecPlan& plan, ExtensionSet* ext_set); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/registry.cc b/cpp/src/arrow/engine/substrait/registry.cc new file mode 100644 index 00000000000..2456598a45b --- /dev/null +++ b/cpp/src/arrow/engine/substrait/registry.cc @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#include "arrow/engine/substrait/registry.h" + +namespace arrow { + +namespace engine { +class SubstraitConversionRegistry::SubstraitConversionRegistryImpl { + public: + explicit SubstraitConversionRegistryImpl( + SubstraitConversionRegistryImpl* parent = NULLPTR) + : parent_(parent) {} + ~SubstraitConversionRegistryImpl() {} + + std::unique_ptr Make() { + return std::unique_ptr( + new SubstraitConversionRegistry()); + } + + std::unique_ptr Make(SubstraitConversionRegistry* parent) { + return std::unique_ptr(new SubstraitConversionRegistry( + new SubstraitConversionRegistry::SubstraitConversionRegistryImpl( + parent->impl_.get()))); + } + + Status RegisterConverter(const std::string& kind_name, SubstraitConverter converter) { + if (kind_name == "scan") { + return Status::NotImplemented("Scan serialization not implemented"); + } else if (kind_name == "filter") { + return Status::NotImplemented("Filter serialization not implemented"); + } else if (kind_name == "project") { + return Status::NotImplemented("Project serialization not implemented"); + } else if (kind_name == "augmented_project") { + return Status::NotImplemented("Augmented Project serialization not implemented"); + } else if (kind_name == "hashjoin") { + return Status::NotImplemented("Filter serialization not implemented"); + } else if (kind_name == "asofjoin") { + return Status::NotImplemented("Asof Join serialization not implemented"); + } else if (kind_name == "select_k_sink") { + return Status::NotImplemented("SelectK serialization not implemented"); + } else if (kind_name == "union") { + return Status::NotImplemented("Union serialization not implemented"); + } else if (kind_name == "write") { + return Status::NotImplemented("Write serialization not implemented"); + } else if (kind_name == "tee") { + return Status::NotImplemented("Tee serialization not implemented"); + } else { + return Status::Invalid("Unsupported ExecNode: ", kind_name); + } + } + + SubstraitConversionRegistryImpl* parent_; + std::mutex lock_; + std::unordered_map name_to_converter_; +}; + +struct DefaultSubstraitConversionRegistry : SubstraitConversionRegistryImpl { + DefaultSubstraitConversionRegistry() {} +}; + +SubstraitConversionRegistry* GetSubstraitConversionRegistry() { + static DefaultSubstraitConversionRegistry impl_; + return &impl_; +} + +} // namespace engine + +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h new file mode 100644 index 00000000000..8fe857770a0 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/serde.h" +#include "arrow/engine/substrait/visibility.h" +#include "arrow/type_fwd.h" + +#include "substrait/algebra.pb.h" // IWYU pragma: export + +namespace arrrow { + +namespace engine { + +using SubstraitConverter = + std::function(Schema, Declaration)>; + +class ARROW_EXPORT SubstraitConversionRegistry { + public: + ~SubstraitConversionRegistry(); + + static std::unique_ptr Make(); + + static std::unique_ptr Make( + SubstraitConversionRegistry* parent); + + Status RegisterConverter(const std::string& kind_name, SubstraitConverter converter); + + private: + SubstraitConversionRegistry(); + + class SubstraitConversionRegistryImpl; + std::unique_ptr impl_; + + explicit SubstraitConversionRegistry(SubstraitConversionRegistryImpl* impl); +}; + +ARROW_EXPORT SubstraitConversionRegistry* GetSubstraitConversionRegistry(); + +} // namespace engine +} // namespace arrrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c5c02f51558..efbe2ab83e2 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -29,8 +29,14 @@ #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/util_internal.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" namespace arrow { + +using internal::checked_cast; +using internal::make_unique; + namespace engine { template @@ -421,5 +427,81 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& rel.DebugString()); } +namespace { + +Result> MakeReadRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set) { + auto read_rel = make_unique(); + const auto& scan_node_options = + checked_cast(*declaration.options); + + auto dataset = + dynamic_cast(scan_node_options.dataset.get()); + if (dataset == nullptr) { + return Status::Invalid("Can only convert file system datasets to a Substrait plan."); + } + // set schema + ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set)); + read_rel->set_allocated_base_schema(named_struct.release()); + + // set local files + auto read_rel_lfs = make_unique(); + for (const auto& file : dataset->files()) { + auto read_rel_lfs_ffs = make_unique(); + read_rel_lfs_ffs->set_uri_path("file://" + file); + + // set file format + // arrow and feather are temporarily handled via the Parquet format until + // upgraded to the latest Substrait version. + auto format_type_name = dataset->format()->type_name(); + if (format_type_name == "parquet") { + auto parquet_fmt = + make_unique(); + read_rel_lfs_ffs->set_allocated_parquet(parquet_fmt.release()); + } else if (format_type_name == "arrow") { + auto arrow_fmt = + make_unique(); + read_rel_lfs_ffs->set_allocated_arrow(arrow_fmt.release()); + } else if (format_type_name == "orc") { + auto orc_fmt = + make_unique(); + read_rel_lfs_ffs->set_allocated_orc(orc_fmt.release()); + } else { + return Status::Invalid("Unsupported file type : ", format_type_name); + } + read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); + } + *read_rel->mutable_local_files() = *read_rel_lfs.get(); + + return read_rel; +} + +Result> MakeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set) { + const std::string& rel_name = declaration.factory_name; + auto rel = make_unique(); + if (rel_name == "scan") { + rel->set_allocated_read(MakeReadRelation(declaration, ext_set)->release()); + } else if (rel_name == "filter") { + return Status::NotImplemented("Filter operator not supported."); + } else if (rel_name == "project") { + return Status::NotImplemented("Project operator not supported."); + } else if (rel_name == "hashjoin") { + return Status::NotImplemented("Join operator not supported."); + } else if (rel_name == "aggregate") { + return Status::NotImplemented("Aggregate operator not supported."); + } else { + return Status::Invalid("Unsupported exec node factory name :", rel_name); + } + return rel; +} + +} // namespace + +Result> ToProto(const compute::Declaration& declaration, + ExtensionSet* ext_set) { + return MakeRelation(declaration, ext_set); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 3699d1f6577..e6b99d58773 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -44,5 +44,12 @@ ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); +ARROW_ENGINE_EXPORT +Result> ToProto(const compute::Declaration&, + ExtensionSet*); + +ARROW_ENGINE_EXPORT std::tuple ScanRelationConverter( + const arrow::Schema&, const compute::Declaration&); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 9f7d979e2f0..0c16611aff6 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,9 +52,22 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } -Result DeserializeRelation( - const Buffer& buf, const ExtensionSet& ext_set, - const ConversionOptions& conversion_options) { +Result> SerializePlan(const compute::ExecPlan& plan, + ExtensionSet* ext_set) { + ARROW_ASSIGN_OR_RAISE(auto subs_plan, ToProto(plan, ext_set)); + std::string serialized = subs_plan->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +Result> SerializeRelation(const compute::Declaration& declaration, + ExtensionSet* ext_set) { + ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set)); + std::string serialized = relation->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +Result DeserializeRelation(const Buffer& buf, + const ExtensionSet& ext_set) { ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); ARROW_ASSIGN_OR_RAISE(auto decl_info, FromProto(rel, ext_set, conversion_options)); return std::move(decl_info.declaration); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 6c2083fb56a..6f4c2bbb2bc 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -36,6 +36,10 @@ namespace arrow { namespace engine { +ARROW_ENGINE_EXPORT +Result> SerializePlan(const compute::ExecPlan& plan, + ExtensionSet* ext_set); + /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. using ConsumerFactory = std::function()>; @@ -124,6 +128,10 @@ ARROW_ENGINE_EXPORT Result> DeserializePlan( const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR, const ConversionOptions& conversion_options = {}); +ARROW_ENGINE_EXPORT Result DeserializePlan( + const Buffer& buf, const ConsumerFactory& consumer_factory, + ExtensionSet* ext_set_out = NULLPTR); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type @@ -202,6 +210,16 @@ Result> SerializeExpression( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); +/// \brief Serializes an Arrow compute Declaration to a Substrait Relation message +/// +/// \param[in] declaration the Arrow compute declaration to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// mappings for the components in the used declaration +/// \return a buffer containing the protobuf serialization of the corresponding Substrait +/// Relation message +ARROW_ENGINE_EXPORT Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set); + /// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 04405b31680..cb2ec8d5339 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/engine/substrait/serde.h" + #include #include #include @@ -22,18 +24,23 @@ #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" -#include "arrow/dataset/file_ipc.h" +#include "arrow/dataset/file_parquet.h" #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" -#include "arrow/engine/substrait/serde.h" #include "arrow/engine/substrait/util.h" +#include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/test_util.h" +#include "arrow/io/compressed.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/writer.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/util/key_value_metadata.h" +#include "parquet/arrow/writer.h" + using testing::ElementsAre; using testing::Eq; using testing::HasSubstr; @@ -45,6 +52,42 @@ using internal::checked_cast; namespace engine { +bool WriteParquetData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr input, const int64_t chunk_size = 3) { + EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path)); + PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(), + buffer_writer, chunk_size)); + return buffer_writer->Close().ok(); +} + +bool CompareDataset(std::shared_ptr ds_lhs, + std::shared_ptr ds_rhs) { + const auto& fsd_lhs = checked_cast(*ds_lhs); + const auto& fsd_rhs = checked_cast(*ds_rhs); + const auto& files_lhs = fsd_lhs.files(); + const auto& files_rhs = fsd_rhs.files(); + + if (files_lhs.size() != files_rhs.size()) { + return false; + } + uint64_t fidx = 0; + for (const auto& l_file : files_lhs) { + if (l_file != files_rhs[fidx++]) { + return false; + } + } + bool cmp_file_format = fsd_lhs.format()->Equals(*fsd_rhs.format()); + bool cmp_file_system = fsd_lhs.filesystem()->Equals(fsd_rhs.filesystem()); + return cmp_file_format && cmp_file_system; +} + +bool CompareScanOptions(const dataset::ScanNodeOptions& lhs, + const dataset::ScanNodeOptions& rhs) { + return lhs.require_sequenced_output == rhs.require_sequenced_output && + CompareDataset(lhs.dataset, rhs.dataset); +} + class NullSinkNodeConsumer : public compute::SinkNodeConsumer { public: Status Init(const std::shared_ptr&, compute::BackpressureControl*) override { @@ -866,6 +909,7 @@ Result GetSubstraitJSON() { auto file_name = arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet"); auto file_path = file_name->ToString(); + std::string substrait_json = R"({ "relations": [ {"rel": { @@ -1814,5 +1858,273 @@ TEST(Substrait, AggregateBadPhase) { ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); } +TEST(Substrait, SerializePlan) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#else + compute::ExecContext exec_context; + ExtensionSet ext_set; + auto dummy_schema = schema({field("lkey", int32()), field("rkey", int32()), + field("shared", int32()), field("ldistinct", int32())}); + // creating a dummy dataset using a dummy table + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + std::vector files1, files2; + const std::vector f_paths = {"/tmp/data1.parquet", "/tmp/data2.parquet"}; + + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files1.push_back(std::move(f_file)); + } + files2 = files1; + + ASSERT_OK_AND_ASSIGN(auto ds_factory1, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files1), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset1, ds_factory1->Finish(dummy_schema)); + + ASSERT_OK_AND_ASSIGN(auto ds_factory2, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files2), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset2, ds_factory2->Finish(dummy_schema)); + + auto scan_options1 = std::make_shared(); + scan_options1->projection = compute::project({}, {}); + + auto scan_options2 = std::make_shared(); + scan_options2->projection = compute::project({}, {}); + + auto filter1 = compute::greater(compute::field_ref("lkey"), compute::literal(3)); + auto filter2 = compute::greater(compute::field_ref("lkey"), compute::literal(2)); + auto filter3 = compute::greater(compute::field_ref("lkey_l1"), compute::literal(1)); + + arrow::AsyncGenerator > sink_gen; + + auto scan_node_options1 = dataset::ScanNodeOptions{dataset1, scan_options1}; + auto scan_node_options2 = dataset::ScanNodeOptions{dataset2, scan_options2}; + auto filter_node_options1 = compute::FilterNodeOptions{filter1}; + auto filter_node_options2 = compute::FilterNodeOptions{filter2}; + auto filter_node_options3 = compute::FilterNodeOptions{filter3}; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + + auto scan_declaration1 = compute::Declaration({"scan", scan_node_options1}); + auto scan_declaration2 = compute::Declaration({"scan", scan_node_options2}); + auto filter_declaration1 = compute::Declaration({"filter", filter_node_options1}); + auto filter_declaration2 = compute::Declaration({"filter", filter_node_options2}); + auto filter_declaration3 = compute::Declaration({"filter", filter_node_options3}); + auto sink_declaration = compute::Declaration({"sink", sink_node_options}); + + auto scan_declarations1 = + compute::Declaration::Sequence({scan_declaration1, filter_declaration1}); + auto scan_declarations2 = + compute::Declaration::Sequence({scan_declaration2, filter_declaration2}); + compute::HashJoinNodeOptions join_node_options{arrow::compute::JoinType::INNER, + /*in_left_keys=*/{"lkey"}, + /*in_right_keys=*/{"rkey"}, + /*filter*/ arrow::compute::literal(true), + /*output_suffix_for_left*/ "_l1", + /*output_suffix_for_right*/ "_r1"}; + + auto join_declaration = compute::Declaration({"hashjoin", join_node_options}); + join_declaration.inputs.emplace_back(scan_declarations1); + join_declaration.inputs.emplace_back(scan_declarations2); + + auto declarations = compute::Declaration::Sequence( + {join_declaration, filter_declaration3, sink_declaration}); + + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN(auto decl, declarations.AddToPlan(plan.get())); + + ASSERT_OK(decl->Validate()); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(*plan, &ext_set)); + // ASSERT_OK_AND_ASSIGN(auto deserialized_plan, + // DeserializeRelation(*serialized_plan, ext_set)); + +#endif +} + +// TEST(Substrait, SerializeRelation) { +// #ifdef _WIN32 +// GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +// #else +// ExtensionSet ext_set; +// auto dummy_schema = schema({field("a", int32()), field("b", int32())}); +// auto table = TableFromJSON(dummy_schema, {R"([ +// [1, 1], +// [3, 4] +// ])", +// R"([ +// [0, 2], +// [1, 3], +// [4, 1], +// [3, 1], +// [1, 2] +// ])", +// R"([ +// [2, 2], +// [5, 3], +// [1, 3] +// ])"}); +// const std::string path = "/testing.parquet"; + +// EXPECT_OK_AND_ASSIGN(auto filesystem, +// fs::internal::MockFileSystem::Make(fs::kNoTime, {})); + +// EXPECT_EQ(WriteParquetData(path, filesystem, table), true); +// // creating a dummy dataset using a dummy table +// auto format = std::make_shared(); + +// std::vector files; +// const std::vector f_paths = {path}; + +// for (const auto& f_path : f_paths) { +// ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); +// files.push_back(std::move(f_file)); +// } + +// ASSERT_OK_AND_ASSIGN(auto ds_factory, +// dataset::FileSystemDatasetFactory::Make( +// filesystem, std::move(files), std::move(format), {})); +// ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + +// auto options = std::make_shared(); +// options->projection = compute::project({}, {}); +// auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; + +// auto scan_declaration = compute::Declaration({"scan", scan_node_options}); + +// ASSERT_OK_AND_ASSIGN(auto serialized_rel, +// SerializeRelation(scan_declaration, &ext_set)); +// ASSERT_OK_AND_ASSIGN(auto deserialized_decl, +// DeserializeRelation(*serialized_rel, ext_set)); + +// auto& mfs = checked_cast(*filesystem); +// mfs.AllFiles(); + +// EXPECT_EQ(deserialized_decl.factory_name, scan_declaration.factory_name); +// const auto& lhs = +// checked_cast(*deserialized_decl.options); +// const auto& rhs = +// checked_cast(*scan_declaration.options); +// ASSERT_TRUE(CompareScanOptions(lhs, rhs)); +// #endif +// } + +// TEST(Substrait, SerializeRelationEndToEnd) { +// #ifdef _WIN32 +// GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +// #else +// ExtensionSet ext_set; +// compute::ExecContext exec_context; + +// auto dummy_schema = schema({field("a", int32()), field("b", int32())}); +// auto table = TableFromJSON(dummy_schema, {R"([ +// [1, 1], +// [3, 4] +// ])", +// R"([ +// [0, 2], +// [1, 3], +// [4, 1], +// [3, 1], +// [1, 2] +// ])", +// R"([ +// [2, 2], +// [5, 3], +// [1, 3] +// ])"}); +// const std::string path = "/testing.parquet"; + +// EXPECT_OK_AND_ASSIGN(auto filesystem, +// fs::internal::MockFileSystem::Make(fs::kNoTime, {})); + +// EXPECT_EQ(WriteParquetData(path, filesystem, table), true); + +// auto format = std::make_shared(); + +// std::vector files; +// ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(path)); +// files.push_back(std::move(f_file)); + +// ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( +// filesystem, files, format, {})); +// ASSERT_OK_AND_ASSIGN(auto other_ds_factory, dataset::FileSystemDatasetFactory::Make( +// std::move(filesystem), +// std::move(files), std::move(format), +// {})); + +// ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); +// ASSERT_OK_AND_ASSIGN(auto other_dataset, other_ds_factory->Finish(dummy_schema)); + +// auto options = std::make_shared(); +// options->projection = compute::project({}, {}); + +// auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; + +// arrow::AsyncGenerator > sink_gen; + +// auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + +// auto scan_declaration = compute::Declaration({"scan", scan_node_options}); +// auto sink_declaration = compute::Declaration({"sink", sink_node_options}); + +// auto declarations = +// compute::Declaration::Sequence({scan_declaration, sink_declaration}); + +// ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); +// ASSERT_OK_AND_ASSIGN(auto decl, declarations.AddToPlan(plan.get())); + +// ASSERT_OK(decl->Validate()); + +// std::shared_ptr sink_reader = compute::MakeGeneratorReader( +// dummy_schema, std::move(sink_gen), exec_context.memory_pool()); + +// ASSERT_OK(plan->Validate()); +// ASSERT_OK(plan->StartProducing()); + +// std::shared_ptr response_table; + +// ASSERT_OK_AND_ASSIGN(response_table, +// arrow::Table::FromRecordBatchReader(sink_reader.get())); + +// auto other_scan_node_options = dataset::ScanNodeOptions{other_dataset, options}; +// auto other_scan_declaration = compute::Declaration({"scan", +// other_scan_node_options}); + +// ASSERT_OK_AND_ASSIGN(auto serialized_rel, +// SerializeRelation(other_scan_declaration, &ext_set)); +// ASSERT_OK_AND_ASSIGN(auto deserialized_decl, +// DeserializeRelation(*serialized_rel, ext_set)); + +// // arrow::AsyncGenerator > des_sink_gen; +// // auto des_sink_node_options = compute::SinkNodeOptions{&des_sink_gen}; + +// // auto des_sink_declaration = compute::Declaration({"sink", des_sink_node_options}); + +// // auto t_decls = +// // compute::Declaration::Sequence({deserialized_decl, des_sink_declaration}); + +// // ASSERT_OK_AND_ASSIGN(auto t_plan, compute::ExecPlan::Make()); +// // ASSERT_OK_AND_ASSIGN(auto t_decl, t_decls.AddToPlan(t_plan.get())); + +// // ASSERT_OK(t_decl->Validate()); + +// // std::shared_ptr des_sink_reader = +// // compute::MakeGeneratorReader(dummy_schema, std::move(des_sink_gen), +// // exec_context.memory_pool()); + +// // ASSERT_OK(t_plan->Validate()); +// // ASSERT_OK(t_plan->StartProducing()); + +// // std::shared_ptr des_response_table; + +// // ASSERT_OK_AND_ASSIGN(des_response_table, +// // arrow::Table::FromRecordBatchReader(des_sink_reader.get())); + +// // ASSERT_TRUE(response_table->Equals(*des_response_table, true)); +// #endif +// } + } // namespace engine } // namespace arrow From 623ef61c184c69f0e38a22cd8bf252b19e65702f Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 28 Jul 2022 09:37:50 +0530 Subject: [PATCH 02/30] fix(registry): initial version of registry def --- cpp/src/arrow/compute/exec/exec_plan.h | 8 +-- cpp/src/arrow/compute/exec/plan_test.cc | 32 +++++------ cpp/src/arrow/dataset/scanner_test.cc | 10 ++-- cpp/src/arrow/engine/CMakeLists.txt | 3 +- .../arrow/engine/substrait/plan_internal.cc | 37 ++++++++----- cpp/src/arrow/engine/substrait/registry.cc | 55 +++++-------------- cpp/src/arrow/engine/substrait/registry.h | 29 +++------- .../engine/substrait/relation_internal.cc | 49 +++++++++++++++++ .../engine/substrait/relation_internal.h | 4 +- 9 files changed, 122 insertions(+), 105 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 4e1a0398867..7be0feaa26d 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -199,13 +199,13 @@ class ARROW_EXPORT ExecNode { /// This node's exec plan ExecPlan* plan() { return plan_; } - /// Set this node's options + /// Set this node's options /// This is an optional method included to support Acero to Substrait /// serialization. - void SetOptions(ExecNodeOptions* options) { options_ = options; } + void SetOptions(std::shared_ptr options) { options_ = options; } /// This node's options - ExecNodeOptions* options() { return options_; } + std::shared_ptr options() { return options_; } /// \brief An optional label, for display and debugging /// @@ -376,7 +376,7 @@ class ARROW_EXPORT ExecNode { util::tracing::Span span_; - ExecNodeOptions* options_; + std::shared_ptr options_; }; /// \brief MapNode is an ExecNode type class which process a task like filter/project diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index e0262142914..906f3ffaa4d 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -280,24 +280,23 @@ TEST(ExecPlanExecution, TableSourceSink) { } TEST(ExecPlanExecution, TableSourceNodeOptions) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); AsyncGenerator> sink_gen; auto exp_batches = MakeBasicBatches(); ASSERT_OK_AND_ASSIGN(auto table, - TableFromExecBatches(exp_batches.schema, exp_batches.batches)); - auto table_source_options = TableSourceNodeOptions{table, 3}; + TableFromExecBatches(exp_batches.schema, exp_batches.batches)); + auto table_source_options = std::make_shared(table, 3); - ASSERT_OK_AND_ASSIGN( - ExecNode * table_source, - MakeExecNode("table_source", plan.get(), {}, table_source_options)); - - table_source->SetOptions(&table_source_options); - const auto& res_table_options = static_cast(*table_source->options()); - - EXPECT_EQ(table_source_options.table, res_table_options.table); - EXPECT_EQ(table_source_options.max_batch_size, res_table_options.max_batch_size); + ASSERT_OK_AND_ASSIGN(ExecNode * table_source, MakeExecNode("table_source", plan.get(), + {}, *table_source_options)); + + table_source->SetOptions(table_source_options); + const auto& res_table_options = + static_cast(*table_source->options()); + + EXPECT_EQ(table_source_options->table, res_table_options.table); + EXPECT_EQ(table_source_options->max_batch_size, res_table_options.max_batch_size); } TEST(ExecPlanExecution, TableSourceSinkError) { @@ -1512,12 +1511,13 @@ TEST(ExecPlan, ExecNodeOption) { AsyncGenerator> sink_gen; - SourceNodeOptions options{input.schema, input.gen(/*parallel*/ true, /*slow=*/false)}; + auto options = std::make_shared( + input.schema, input.gen(/*parallel*/ true, /*slow=*/false)); - ASSERT_OK_AND_ASSIGN(auto* source, MakeExecNode("source", plan.get(), {}, options)); - source->SetOptions(&options); + ASSERT_OK_AND_ASSIGN(auto* source, MakeExecNode("source", plan.get(), {}, *options)); + source->SetOptions(options); const auto& opts = static_cast(*source->options()); - ASSERT_EQ(opts.output_schema, options.output_schema); + ASSERT_EQ(options->output_schema, opts.output_schema); } } // namespace compute diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 982306725c7..d9343578b44 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1998,14 +1998,14 @@ TEST(ScanNode, NodeOptions) { auto options = std::make_shared(); options->projection = Materialize({}); // set an empty projection - ScanNodeOptions scan_node_options{basic.dataset, options}; + auto scan_node_options = std::make_shared(basic.dataset, options); ASSERT_OK_AND_ASSIGN(auto* scan, - compute::MakeExecNode("scan", plan.get(), {}, scan_node_options)); - scan->SetOptions(&scan_node_options); + compute::MakeExecNode("scan", plan.get(), {}, *scan_node_options)); + scan->SetOptions(scan_node_options); const auto& res_scan_options = static_cast(*scan->options()); - ASSERT_EQ(scan_node_options.dataset->schema(), res_scan_options.dataset->schema()); - ASSERT_EQ(scan_node_options.scan_options->projection, + ASSERT_EQ(scan_node_options->dataset->schema(), res_scan_options.dataset->schema()); + ASSERT_EQ(scan_node_options->scan_options->projection, res_scan_options.scan_options->projection); } diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index a8d5be90af8..f3e39a37c53 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -25,8 +25,7 @@ set(ARROW_SUBSTRAIT_SRCS substrait/extension_types.cc substrait/plan_internal.cc substrait/relation_internal.cc - substrait/serde.cc - substrait/test_plan_builder.cc + substrait/registry.cc substrait/type_internal.cc substrait/util.cc) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 952a3f47f87..ac99b710e0b 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -16,7 +16,8 @@ // under the License. #include "arrow/engine/substrait/plan_internal.h" - +#include "arrow/dataset/plan.h" +#include "arrow/dataset/scanner.h" #include "arrow/result.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" @@ -145,17 +146,20 @@ Status TraversePlan(compute::ExecNode* node, return Status::OK(); } -// arrow::Result MakeDeclaration(compute::ExecNode* node) { -// compute::Declaration decl; -// const auto* kind_name = node->kind_name(); -// if(kind_name == std::string("scan")) { -// std::vector inputs = node->inputs(); - -// } else { -// return Status::NotImplemented(kind_name, " relation not implemented."); -// } -// return decl; -// } +Result MakeDeclaration(compute::ExecNode* node) { + const auto* kind_name = node->kind_name(); + if (kind_name == std::string("scan")) { + std::vector> inputs( + node->num_inputs()); + int64_t idx = 0; + for (auto* input : node->inputs()) { + inputs[idx++] = input; + } + return compute::Declaration(kind_name, std::move(inputs), node->options()); + } else { + return Status::NotImplemented(kind_name, " relation not implemented."); + } +} Result> ToProto(const compute::ExecPlan& plan, ExtensionSet* ext_set) { @@ -173,9 +177,12 @@ Result> ToProto(const compute::ExecPlan& pla for (const auto& node : node_vec) { const auto* kind_name = node->kind_name(); std::cout << kind_name << std::endl; - //const auto& output_schema = node->output_schema(); - // if (output_schema) { - // } + const auto& output_schema = node->output_schema(); + auto declaration = MakeDeclaration(node); + if (output_schema) { + std::cout << "Schema >>> " << std::endl; + std::cout << output_schema->ToString(false) << std::endl; + } } return plan_rel; diff --git a/cpp/src/arrow/engine/substrait/registry.cc b/cpp/src/arrow/engine/substrait/registry.cc index 2456598a45b..f28867318f3 100644 --- a/cpp/src/arrow/engine/substrait/registry.cc +++ b/cpp/src/arrow/engine/substrait/registry.cc @@ -19,56 +19,30 @@ // deprecation cycle #include "arrow/engine/substrait/registry.h" +#include "arrow/engine/substrait/relation_internal.h" namespace arrow { namespace engine { -class SubstraitConversionRegistry::SubstraitConversionRegistryImpl { + +class SubstraitConversionRegistryImpl : public SubstraitConversionRegistry { public: - explicit SubstraitConversionRegistryImpl( - SubstraitConversionRegistryImpl* parent = NULLPTR) - : parent_(parent) {} - ~SubstraitConversionRegistryImpl() {} + SubstraitConversionRegistryImpl(); - std::unique_ptr Make() { - return std::unique_ptr( - new SubstraitConversionRegistry()); - } + Status RegisterConverter(std::string factory_name, + SubstraitConverter converter) override { + auto it_success = + name_to_converter_.emplace(std::move(factory_name), std::move(converter)); - std::unique_ptr Make(SubstraitConversionRegistry* parent) { - return std::unique_ptr(new SubstraitConversionRegistry( - new SubstraitConversionRegistry::SubstraitConversionRegistryImpl( - parent->impl_.get()))); - } - - Status RegisterConverter(const std::string& kind_name, SubstraitConverter converter) { - if (kind_name == "scan") { - return Status::NotImplemented("Scan serialization not implemented"); - } else if (kind_name == "filter") { - return Status::NotImplemented("Filter serialization not implemented"); - } else if (kind_name == "project") { - return Status::NotImplemented("Project serialization not implemented"); - } else if (kind_name == "augmented_project") { - return Status::NotImplemented("Augmented Project serialization not implemented"); - } else if (kind_name == "hashjoin") { - return Status::NotImplemented("Filter serialization not implemented"); - } else if (kind_name == "asofjoin") { - return Status::NotImplemented("Asof Join serialization not implemented"); - } else if (kind_name == "select_k_sink") { - return Status::NotImplemented("SelectK serialization not implemented"); - } else if (kind_name == "union") { - return Status::NotImplemented("Union serialization not implemented"); - } else if (kind_name == "write") { - return Status::NotImplemented("Write serialization not implemented"); - } else if (kind_name == "tee") { - return Status::NotImplemented("Tee serialization not implemented"); - } else { - return Status::Invalid("Unsupported ExecNode: ", kind_name); + if (!it_success.second) { + const auto& factory_name = it_success.first->first; + return Status::KeyError("SubstraitConverter named ", factory_name, + " already registered."); } + return Status::OK(); } - SubstraitConversionRegistryImpl* parent_; - std::mutex lock_; + private: std::unordered_map name_to_converter_; }; @@ -82,5 +56,4 @@ SubstraitConversionRegistry* GetSubstraitConversionRegistry() { } } // namespace engine - } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h index 8fe857770a0..8761d8349b5 100644 --- a/cpp/src/arrow/engine/substrait/registry.h +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -29,41 +29,30 @@ #include "arrow/util/visibility.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/relation_internal.h" #include "arrow/engine/substrait/serde.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" #include "substrait/algebra.pb.h" // IWYU pragma: export -namespace arrrow { +namespace arrow { namespace engine { -using SubstraitConverter = - std::function(Schema, Declaration)>; - class ARROW_EXPORT SubstraitConversionRegistry { public: - ~SubstraitConversionRegistry(); - - static std::unique_ptr Make(); - - static std::unique_ptr Make( - SubstraitConversionRegistry* parent); - - Status RegisterConverter(const std::string& kind_name, SubstraitConverter converter); - - private: - SubstraitConversionRegistry(); - - class SubstraitConversionRegistryImpl; - std::unique_ptr impl_; + virtual ~SubstraitConversionRegistry() = default; + using SubstraitConverter = std::function>( + const std::shared_ptr&, const compute::Declaration&, ExtensionSet*)>; - explicit SubstraitConversionRegistry(SubstraitConversionRegistryImpl* impl); + virtual Status RegisterConverter(std::string factory_name, + SubstraitConverter converter) = 0; }; ARROW_EXPORT SubstraitConversionRegistry* GetSubstraitConversionRegistry(); } // namespace engine -} // namespace arrrow +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index efbe2ab83e2..42a43a533f4 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -503,5 +503,54 @@ Result> ToProto(const compute::Declaration& decl return MakeRelation(declaration, ext_set); } +Result> ScanRelationConverter( + const std::shared_ptr& schema, const compute::Declaration& declaration, + ExtensionSet* ext_set) { + auto rel = make_unique(); + auto read_rel = make_unique(); + const auto& scan_node_options = + checked_cast(*declaration.options); + auto dataset = + dynamic_cast(scan_node_options.dataset.get()); + if (dataset == nullptr) { + return Status::Invalid("Can only convert file system datasets to a Substrait plan."); + } + // set schema + ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set)); + read_rel->set_allocated_base_schema(named_struct.release()); + + // set local files + auto read_rel_lfs = make_unique(); + for (const auto& file : dataset->files()) { + auto read_rel_lfs_ffs = make_unique(); + read_rel_lfs_ffs->set_uri_path("file://" + file); + + // set file format + // arrow and feather are temporarily handled via the Parquet format until + // upgraded to the latest Substrait version. + auto format_type_name = dataset->format()->type_name(); + if (format_type_name == "parquet") { + auto parquet_fmt = + make_unique(); + read_rel_lfs_ffs->set_allocated_parquet(parquet_fmt.release()); + } else if (format_type_name == "arrow") { + auto arrow_fmt = + make_unique(); + read_rel_lfs_ffs->set_allocated_arrow(arrow_fmt.release()); + } else if (format_type_name == "orc") { + auto orc_fmt = + make_unique(); + read_rel_lfs_ffs->set_allocated_orc(orc_fmt.release()); + } else { + return Status::Invalid("Unsupported file type : ", format_type_name); + } + read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); + } + // TODO(Before PR Merge) : evaluate better hand-off of pointers + *read_rel->mutable_local_files() = *read_rel_lfs.get(); + rel->set_allocated_read(read_rel.release()); + return std::move(rel); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index e6b99d58773..e1b95dc210d 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -48,8 +48,8 @@ ARROW_ENGINE_EXPORT Result> ToProto(const compute::Declaration&, ExtensionSet*); -ARROW_ENGINE_EXPORT std::tuple ScanRelationConverter( - const arrow::Schema&, const compute::Declaration&); +ARROW_ENGINE_EXPORT Result> ScanRelationConverter( + const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set); } // namespace engine } // namespace arrow From 49647c700500321c0b59bec1a412491982dae98f Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 28 Jul 2022 17:12:51 +0530 Subject: [PATCH 03/30] adding registry end-to-end WIP --- .../arrow/engine/substrait/plan_internal.cc | 24 +++++++++++++++++-- cpp/src/arrow/engine/substrait/registry.cc | 17 ++++++++++--- cpp/src/arrow/engine/substrait/registry.h | 4 +++- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index ac99b710e0b..e20704114b4 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -16,6 +16,7 @@ // under the License. #include "arrow/engine/substrait/plan_internal.h" + #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/result.h" @@ -24,6 +25,8 @@ #include "arrow/util/make_unique.h" #include "arrow/util/unreachable.h" +#include "arrow/engine/substrait/registry.h" + #include namespace arrow { @@ -161,10 +164,21 @@ Result MakeDeclaration(compute::ExecNode* node) { } } +Status SetRelation(const std::unique_ptr& plan, + const std::unique_ptr& partial_plan, + const std::string& factory_name) { + if (factory_name == "scan" && partial_plan->has_read()) { + plan->set_allocated_read(partial_plan->release_read()); + } else { + return Status::NotImplemented("Substrait converter ", factory_name, "not supported."); + } + return Status::OK(); +} + Result> ToProto(const compute::ExecPlan& plan, ExtensionSet* ext_set) { auto plan_rel = internal::make_unique(); - + auto rel = internal::make_unique(); std::cout << "Plan Show" << std::endl; std::cout << plan.ToString() << std::endl; @@ -173,15 +187,21 @@ Result> ToProto(const compute::ExecPlan& pla std::vector node_vec; ARROW_RETURN_NOT_OK(TraversePlan(sinks[0], &node_vec)); + SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); + std::cout << "Printing Node Vector" << std::endl; for (const auto& node : node_vec) { const auto* kind_name = node->kind_name(); std::cout << kind_name << std::endl; const auto& output_schema = node->output_schema(); - auto declaration = MakeDeclaration(node); + ARROW_ASSIGN_OR_RAISE(auto declaration, MakeDeclaration(node)); + ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(kind_name)); if (output_schema) { std::cout << "Schema >>> " << std::endl; std::cout << output_schema->ToString(false) << std::endl; + ARROW_ASSIGN_OR_RAISE(auto factory_rel, + factory(output_schema, declaration, ext_set)); + RETURN_NOT_OK(SetRelation(rel, factory_rel, kind_name)); } } diff --git a/cpp/src/arrow/engine/substrait/registry.cc b/cpp/src/arrow/engine/substrait/registry.cc index f28867318f3..83261a8bdce 100644 --- a/cpp/src/arrow/engine/substrait/registry.cc +++ b/cpp/src/arrow/engine/substrait/registry.cc @@ -27,7 +27,16 @@ namespace engine { class SubstraitConversionRegistryImpl : public SubstraitConversionRegistry { public: - SubstraitConversionRegistryImpl(); + virtual ~SubstraitConversionRegistryImpl() {} + + Result GetConverter(const std::string& factory_name) override { + auto it = name_to_converter_.find(factory_name); + if (it == name_to_converter_.end()) { + return Status::KeyError("SubstraitConverter named ", factory_name, + " not present in registry."); + } + return it->second; + } Status RegisterConverter(std::string factory_name, SubstraitConverter converter) override { @@ -47,10 +56,12 @@ class SubstraitConversionRegistryImpl : public SubstraitConversionRegistry { }; struct DefaultSubstraitConversionRegistry : SubstraitConversionRegistryImpl { - DefaultSubstraitConversionRegistry() {} + DefaultSubstraitConversionRegistry() { + DCHECK_OK(RegisterConverter("scan", ScanRelationConverter)); + } }; -SubstraitConversionRegistry* GetSubstraitConversionRegistry() { +SubstraitConversionRegistry* default_substrait_conversion_registry() { static DefaultSubstraitConversionRegistry impl_; return &impl_; } diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h index 8761d8349b5..94f46e1dc7c 100644 --- a/cpp/src/arrow/engine/substrait/registry.h +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -48,11 +48,13 @@ class ARROW_EXPORT SubstraitConversionRegistry { using SubstraitConverter = std::function>( const std::shared_ptr&, const compute::Declaration&, ExtensionSet*)>; + virtual Result GetConverter(const std::string& factory_name) = 0; + virtual Status RegisterConverter(std::string factory_name, SubstraitConverter converter) = 0; }; -ARROW_EXPORT SubstraitConversionRegistry* GetSubstraitConversionRegistry(); +ARROW_EXPORT SubstraitConversionRegistry* default_substrait_conversion_registry(); } // namespace engine } // namespace arrow From 57347db84fefd525d5daf087245d2216e8790577 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 1 Aug 2022 15:55:34 +0530 Subject: [PATCH 04/30] fix(scan): include initial version of registry test WIP --- .../arrow/engine/substrait/plan_internal.cc | 99 +++++++++---------- .../arrow/engine/substrait/plan_internal.h | 5 +- cpp/src/arrow/engine/substrait/registry.cc | 1 + .../engine/substrait/relation_internal.cc | 38 +++++++ .../engine/substrait/relation_internal.h | 3 + cpp/src/arrow/engine/substrait/serde.cc | 5 +- cpp/src/arrow/engine/substrait/serde.h | 3 +- cpp/src/arrow/engine/substrait/serde_test.cc | 77 +++++---------- 8 files changed, 117 insertions(+), 114 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index e20704114b4..c6c3878c714 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -19,14 +19,14 @@ #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" +#include "arrow/engine/substrait/registry.h" #include "arrow/result.h" +#include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" #include "arrow/util/unreachable.h" -#include "arrow/engine/substrait/registry.h" - #include namespace arrow { @@ -36,6 +36,7 @@ using internal::checked_cast; namespace engine { namespace internal { +using ::arrow::internal::hash_combine; using ::arrow::internal::make_unique; } // namespace internal @@ -137,74 +138,62 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, registry); } -Status TraversePlan(compute::ExecNode* node, - std::vector* node_vector) { - if (node->num_inputs() > 0) { - const auto& sources = node->inputs(); - for (const auto& source : sources) { - ARROW_RETURN_NOT_OK(TraversePlan(source, node_vector)); - } +Status SetRelation(const std::unique_ptr& plan, + const std::unique_ptr& partial_plan, + const std::string& factory_name) { + if (factory_name == "scan" && partial_plan->has_read()) { + plan->set_allocated_read(partial_plan->release_read()); + } else if (factory_name == "filter" && partial_plan->has_filter()) { + plan->set_allocated_filter(partial_plan->release_filter()); + } else { + return Status::NotImplemented("Substrait converter ", factory_name, + " not supported."); } - node_vector->push_back(node); return Status::OK(); } -Result MakeDeclaration(compute::ExecNode* node) { - const auto* kind_name = node->kind_name(); - if (kind_name == std::string("scan")) { - std::vector> inputs( - node->num_inputs()); - int64_t idx = 0; - for (auto* input : node->inputs()) { - inputs[idx++] = input; - } - return compute::Declaration(kind_name, std::move(inputs), node->options()); +Result> ExtractSchemaToBind(const compute::Declaration& declr) { + std::shared_ptr bind_schema; + if (declr.factory_name == "scan") { + const auto& opts = checked_cast(*(declr.options)); + bind_schema = opts.dataset->schema(); + } else if (declr.factory_name == "filter") { + auto input_declr = util::get(declr.inputs[0]); + ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); + } else if (declr.factory_name == "hashjoin") { + } else if (declr.factory_name == "sink") { + return bind_schema; } else { - return Status::NotImplemented(kind_name, " relation not implemented."); + return Status::Invalid("Schema extraction failed, unsupported factory ", + declr.factory_name); } + return bind_schema; } -Status SetRelation(const std::unique_ptr& plan, - const std::unique_ptr& partial_plan, - const std::string& factory_name) { - if (factory_name == "scan" && partial_plan->has_read()) { - plan->set_allocated_read(partial_plan->release_read()); - } else { - return Status::NotImplemented("Substrait converter ", factory_name, "not supported."); +Status TraverseDeclarations(const compute::Declaration& declaration, + ExtensionSet* ext_set, std::unique_ptr& rel) { + std::vector inputs = declaration.inputs; + for (auto& input : inputs) { + auto input_decl = util::get(input); + RETURN_NOT_OK(TraverseDeclarations(input_decl, ext_set, rel)); + } + const auto& factory_name = declaration.factory_name; + std::cout << factory_name << std::endl; + ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); + SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); + if (factory_name != "sink") { + ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); + ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set)); } return Status::OK(); } -Result> ToProto(const compute::ExecPlan& plan, +Result> ToProto(compute::ExecPlan* plan, + const compute::Declaration& declr, ExtensionSet* ext_set) { auto plan_rel = internal::make_unique(); auto rel = internal::make_unique(); - std::cout << "Plan Show" << std::endl; - - std::cout << plan.ToString() << std::endl; - std::cout << "----------" << std::endl; - const auto& sinks = plan.sinks(); - std::vector node_vec; - ARROW_RETURN_NOT_OK(TraversePlan(sinks[0], &node_vec)); - - SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); - - std::cout << "Printing Node Vector" << std::endl; - for (const auto& node : node_vec) { - const auto* kind_name = node->kind_name(); - std::cout << kind_name << std::endl; - const auto& output_schema = node->output_schema(); - ARROW_ASSIGN_OR_RAISE(auto declaration, MakeDeclaration(node)); - ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(kind_name)); - if (output_schema) { - std::cout << "Schema >>> " << std::endl; - std::cout << output_schema->ToString(false) << std::endl; - ARROW_ASSIGN_OR_RAISE(auto factory_rel, - factory(output_schema, declaration, ext_set)); - RETURN_NOT_OK(SetRelation(rel, factory_rel, kind_name)); - } - } - + RETURN_NOT_OK(TraverseDeclarations(declr, ext_set, rel)); return plan_rel; } diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 2ae3cdd3ec2..686d10b2dbf 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -54,7 +54,10 @@ Result GetExtensionSetFromPlan( const ExtensionIdRegistry* registry = default_extension_id_registry()); ARROW_ENGINE_EXPORT Result> ToProto( - const compute::ExecPlan& plan, ExtensionSet* ext_set); + compute::ExecPlan* plan, const compute::Declaration& declr, ExtensionSet* ext_set); + +// ARROW_ENGINE_EXPORT Result> ToProto( +// const compute::Declaration& declaration, ExtensionSet* ext_set); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/registry.cc b/cpp/src/arrow/engine/substrait/registry.cc index 83261a8bdce..237a689b5ac 100644 --- a/cpp/src/arrow/engine/substrait/registry.cc +++ b/cpp/src/arrow/engine/substrait/registry.cc @@ -58,6 +58,7 @@ class SubstraitConversionRegistryImpl : public SubstraitConversionRegistry { struct DefaultSubstraitConversionRegistry : SubstraitConversionRegistryImpl { DefaultSubstraitConversionRegistry() { DCHECK_OK(RegisterConverter("scan", ScanRelationConverter)); + DCHECK_OK(RegisterConverter("filter", FilterRelationConverter)); } }; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 42a43a533f4..055ee9f746a 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -503,6 +503,16 @@ Result> ToProto(const compute::Declaration& decl return MakeRelation(declaration, ext_set); } +Result> GetRelationFromDeclaration( + const compute::Declaration& declaration, ExtensionSet* ext_set) { + auto declr_input = declaration.inputs[0]; + // TODO: figure out a better way + if (util::get_if(&declr_input)) { + return Status::NotImplemented("Only support Plans written in Declaration format."); + } + return ToProto(util::get(declr_input), ext_set); +} + Result> ScanRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set) { @@ -552,5 +562,33 @@ Result> ScanRelationConverter( return std::move(rel); } +Result> FilterRelationConverter( + const std::shared_ptr& schema, const compute::Declaration& declaration, + ExtensionSet* ext_set) { + auto rel = make_unique(); + auto filter_rel = make_unique(); + const auto& filter_node_options = + checked_cast(*(declaration.options)); + + auto filter_expr = filter_node_options.filter_expression; + compute::Expression bound_expression; + if (!filter_expr.IsBound()) { + ARROW_ASSIGN_OR_RAISE(bound_expression, filter_expr.Bind(*schema)); + } + + if (declaration.inputs.size() == 0) { + return Status::Invalid("Filter node doesn't have an input."); + } + + auto input_rel = GetRelationFromDeclaration(declaration, ext_set); + + filter_rel->set_allocated_input(input_rel->release()); + + ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set)); + *filter_rel->mutable_condition() = *subs_expr.get(); + rel->set_allocated_filter(filter_rel.release()); + return rel; +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index e1b95dc210d..d1fec63abcc 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -51,5 +51,8 @@ Result> ToProto(const compute::Declaration&, ARROW_ENGINE_EXPORT Result> ScanRelationConverter( const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set); +ARROW_ENGINE_EXPORT Result> FilterRelationConverter( + const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 0c16611aff6..58d307e1201 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,9 +52,10 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } -Result> SerializePlan(const compute::ExecPlan& plan, +Result> SerializePlan(compute::ExecPlan* plan, + const compute::Declaration& declr, ExtensionSet* ext_set) { - ARROW_ASSIGN_OR_RAISE(auto subs_plan, ToProto(plan, ext_set)); + ARROW_ASSIGN_OR_RAISE(auto subs_plan, ToProto(plan, declr, ext_set)); std::string serialized = subs_plan->SerializeAsString(); return Buffer::FromString(std::move(serialized)); } diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 6f4c2bbb2bc..d573d55a7b1 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -37,7 +37,8 @@ namespace arrow { namespace engine { ARROW_ENGINE_EXPORT -Result> SerializePlan(const compute::ExecPlan& plan, +Result> SerializePlan(compute::ExecPlan* plan, + const compute::Declaration& declr, ExtensionSet* ext_set); /// Factory function type for generating the node that consumes the batches produced by diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index cb2ec8d5339..f66b0fc20ce 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -41,6 +41,9 @@ #include "parquet/arrow/writer.h" +#include "arrow/util/hash_util.h" +#include "arrow/util/hashing.h" + using testing::ElementsAre; using testing::Eq; using testing::HasSubstr; @@ -49,7 +52,7 @@ using testing::UnorderedElementsAre; namespace arrow { using internal::checked_cast; - +using internal::hash_combine; namespace engine { bool WriteParquetData(const std::string& path, @@ -1864,82 +1867,46 @@ TEST(Substrait, SerializePlan) { #else compute::ExecContext exec_context; ExtensionSet ext_set; - auto dummy_schema = schema({field("lkey", int32()), field("rkey", int32()), - field("shared", int32()), field("ldistinct", int32())}); + auto dummy_schema = schema( + {field("key", int32()), field("shared", int32()), field("distinct", int32())}); // creating a dummy dataset using a dummy table auto format = std::make_shared(); auto filesystem = std::make_shared(); - std::vector files1, files2; + std::vector files; const std::vector f_paths = {"/tmp/data1.parquet", "/tmp/data2.parquet"}; for (const auto& f_path : f_paths) { ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); - files1.push_back(std::move(f_file)); + files.push_back(std::move(f_file)); } - files2 = files1; - - ASSERT_OK_AND_ASSIGN(auto ds_factory1, dataset::FileSystemDatasetFactory::Make( - filesystem, std::move(files1), format, {})); - ASSERT_OK_AND_ASSIGN(auto dataset1, ds_factory1->Finish(dummy_schema)); - ASSERT_OK_AND_ASSIGN(auto ds_factory2, dataset::FileSystemDatasetFactory::Make( - filesystem, std::move(files2), format, {})); - ASSERT_OK_AND_ASSIGN(auto dataset2, ds_factory2->Finish(dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); - auto scan_options1 = std::make_shared(); - scan_options1->projection = compute::project({}, {}); + auto scan_options = std::make_shared(); + scan_options->projection = compute::project({}, {}); - auto scan_options2 = std::make_shared(); - scan_options2->projection = compute::project({}, {}); - - auto filter1 = compute::greater(compute::field_ref("lkey"), compute::literal(3)); - auto filter2 = compute::greater(compute::field_ref("lkey"), compute::literal(2)); - auto filter3 = compute::greater(compute::field_ref("lkey_l1"), compute::literal(1)); + auto filter = compute::equal(compute::field_ref("key"), compute::literal(3)); arrow::AsyncGenerator > sink_gen; - auto scan_node_options1 = dataset::ScanNodeOptions{dataset1, scan_options1}; - auto scan_node_options2 = dataset::ScanNodeOptions{dataset2, scan_options2}; - auto filter_node_options1 = compute::FilterNodeOptions{filter1}; - auto filter_node_options2 = compute::FilterNodeOptions{filter2}; - auto filter_node_options3 = compute::FilterNodeOptions{filter3}; + auto scan_node_options = dataset::ScanNodeOptions{dataset, scan_options}; + auto filter_node_options = compute::FilterNodeOptions{filter}; auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - auto scan_declaration1 = compute::Declaration({"scan", scan_node_options1}); - auto scan_declaration2 = compute::Declaration({"scan", scan_node_options2}); - auto filter_declaration1 = compute::Declaration({"filter", filter_node_options1}); - auto filter_declaration2 = compute::Declaration({"filter", filter_node_options2}); - auto filter_declaration3 = compute::Declaration({"filter", filter_node_options3}); - auto sink_declaration = compute::Declaration({"sink", sink_node_options}); - - auto scan_declarations1 = - compute::Declaration::Sequence({scan_declaration1, filter_declaration1}); - auto scan_declarations2 = - compute::Declaration::Sequence({scan_declaration2, filter_declaration2}); - compute::HashJoinNodeOptions join_node_options{arrow::compute::JoinType::INNER, - /*in_left_keys=*/{"lkey"}, - /*in_right_keys=*/{"rkey"}, - /*filter*/ arrow::compute::literal(true), - /*output_suffix_for_left*/ "_l1", - /*output_suffix_for_right*/ "_r1"}; - - auto join_declaration = compute::Declaration({"hashjoin", join_node_options}); - join_declaration.inputs.emplace_back(scan_declarations1); - join_declaration.inputs.emplace_back(scan_declarations2); + auto scan_declaration = compute::Declaration({"scan", scan_node_options, "s"}); + auto filter_declaration = compute::Declaration({"filter", filter_node_options, "f"}); + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); auto declarations = compute::Declaration::Sequence( - {join_declaration, filter_declaration3, sink_declaration}); + {scan_declaration, filter_declaration, sink_declaration}); ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN(auto decl, declarations.AddToPlan(plan.get())); - - ASSERT_OK(decl->Validate()); - - ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(*plan, &ext_set)); - // ASSERT_OK_AND_ASSIGN(auto deserialized_plan, - // DeserializeRelation(*serialized_plan, ext_set)); + ASSERT_OK_AND_ASSIGN(auto serialized_plan, + SerializePlan(plan.get(), declarations, &ext_set)); #endif } From 00549f755999c379619c955badeec89e85641790 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 1 Aug 2022 20:22:29 +0530 Subject: [PATCH 05/30] adding initial testing on registry-based roundtrip tests for plan/relations --- .../arrow/engine/substrait/plan_internal.cc | 22 +++++++------ .../arrow/engine/substrait/plan_internal.h | 2 +- cpp/src/arrow/engine/substrait/serde_test.cc | 32 +++++++++++++++++-- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index c6c3878c714..fdce5eeba2b 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -170,31 +170,35 @@ Result> ExtractSchemaToBind(const compute::Declaration& return bind_schema; } -Status TraverseDeclarations(const compute::Declaration& declaration, - ExtensionSet* ext_set, std::unique_ptr& rel) { +Status SerializeRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, + std::unique_ptr& rel) { std::vector inputs = declaration.inputs; for (auto& input : inputs) { auto input_decl = util::get(input); - RETURN_NOT_OK(TraverseDeclarations(input_decl, ext_set, rel)); + RETURN_NOT_OK(SerializeRelations(input_decl, ext_set, rel)); } const auto& factory_name = declaration.factory_name; - std::cout << factory_name << std::endl; ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); if (factory_name != "sink") { ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set)); + RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); } return Status::OK(); } -Result> ToProto(compute::ExecPlan* plan, - const compute::Declaration& declr, - ExtensionSet* ext_set) { +Result> ToProto(compute::ExecPlan* plan, + const compute::Declaration& declr, + ExtensionSet* ext_set) { + auto subs_plan = internal::make_unique(); auto plan_rel = internal::make_unique(); auto rel = internal::make_unique(); - RETURN_NOT_OK(TraverseDeclarations(declr, ext_set, rel)); - return plan_rel; + RETURN_NOT_OK(SerializeRelations(declr, ext_set, rel)); + plan_rel->set_allocated_rel(rel.release()); + subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); + RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); + return std::move(subs_plan); } } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 686d10b2dbf..81aaabad029 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -53,7 +53,7 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); -ARROW_ENGINE_EXPORT Result> ToProto( +ARROW_ENGINE_EXPORT Result> ToProto( compute::ExecPlan* plan, const compute::Declaration& declr, ExtensionSet* ext_set); // ARROW_ENGINE_EXPORT Result> ToProto( diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index f66b0fc20ce..c040555b0d7 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1887,8 +1887,8 @@ TEST(Substrait, SerializePlan) { auto scan_options = std::make_shared(); scan_options->projection = compute::project({}, {}); - - auto filter = compute::equal(compute::field_ref("key"), compute::literal(3)); + const std::string filter_col = "shared"; + auto filter = compute::equal(compute::field_ref(filter_col), compute::literal(3)); arrow::AsyncGenerator > sink_gen; @@ -1907,6 +1907,34 @@ TEST(Substrait, SerializePlan) { ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(plan.get(), declarations, &ext_set)); + + for (auto sp_ext_id_reg : + {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[index], filter_col); + EXPECT_EQ(args[1], compute::literal(3)); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + EXPECT_TRUE(dataset_opts.dataset->schema()->Equals(*dummy_schema)); + } #endif } From 87c5328f029bd984740d7b157ac39633ce48238a Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 2 Aug 2022 12:01:09 +0530 Subject: [PATCH 06/30] feat(test): adding end-to-end testing --- cpp/src/arrow/engine/substrait/serde_test.cc | 70 ++++++++++++++++++-- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index c040555b0d7..f88e78ba722 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -55,15 +55,26 @@ using internal::checked_cast; using internal::hash_combine; namespace engine { -bool WriteParquetData(const std::string& path, - const std::shared_ptr file_system, - const std::shared_ptr
input, const int64_t chunk_size = 3) { +Status WriteParquetData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr
input, + const int64_t chunk_size = 3) { EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path)); PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(), buffer_writer, chunk_size)); - return buffer_writer->Close().ok(); + return buffer_writer->Close(); } +// Result WriteTemporaryData(const std::string& file_name, const +// std::shared_ptr
& table, const std::shared_ptr& filesystem) +// { +// ARROW_ASSIGN_OR_RAISE(auto tempdir, +// arrow::internal::TemporaryDir::Make("substrait_tempdir")); ARROW_ASSIGN_OR_RAISE(auto +// file_path, tempdir->path().Join(file_name)); std::string file_path_str = +// file_path.ToString(); EXPECT_EQ(WriteParquetData(file_path_str, filesystem, table), +// true); return file_path_str; +// } + bool CompareDataset(std::shared_ptr ds_lhs, std::shared_ptr ds_rhs) { const auto& fsd_lhs = checked_cast(*ds_lhs); @@ -1870,11 +1881,44 @@ TEST(Substrait, SerializePlan) { auto dummy_schema = schema( {field("key", int32()), field("shared", int32()), field("distinct", int32())}); // creating a dummy dataset using a dummy table + + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])", + R"([ + [0, 2, 1], + [1, 3, 2], + [4, 1, 3], + [3, 1, 3], + [1, 2, 5] + ])", + R"([ + [2, 2, 12], + [5, 3, 12], + [1, 3, 12] + ])"}); + auto format = std::make_shared(); auto filesystem = std::make_shared(); + const std::string file_name = "serde_test.parquet"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_tempdir")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + // Note: there is an additional forward slash introduced by the tempdir + // it must be replaced to properly load into reading files + // TODO: (Review: Jira needs to be reported to handle this properly) + std::string toReplace("/T//"); + size_t pos = file_path_str.find(toReplace); + file_path_str.replace(pos, toReplace.length(), "/T/"); + + ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); std::vector files; - const std::vector f_paths = {"/tmp/data1.parquet", "/tmp/data2.parquet"}; + const std::vector f_paths = {file_path_str}; for (const auto& f_path : f_paths) { ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); @@ -1933,7 +1977,21 @@ TEST(Substrait, SerializePlan) { auto roundtripped_scan = roundtripped_filter->inputs[0].get(); const auto& dataset_opts = checked_cast(*(roundtripped_scan->options)); - EXPECT_TRUE(dataset_opts.dataset->schema()->Equals(*dummy_schema)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); + } } #endif } From 53e2245998260bd69ea2380274f500651707297e Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 2 Aug 2022 12:32:19 +0530 Subject: [PATCH 07/30] feat(test-end-to-end): added an end-to-end test case --- cpp/src/arrow/engine/substrait/serde_test.cc | 303 +++++++++---------- 1 file changed, 137 insertions(+), 166 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index f88e78ba722..f6bb351296f 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -65,6 +65,22 @@ Status WriteParquetData(const std::string& path, return buffer_writer->Close(); } +Result> GetTableFromPlan( + std::shared_ptr& plan, compute::Declaration& declarations, + arrow::AsyncGenerator>& sink_gen, + compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); + + RETURN_NOT_OK(decl->Validate()); + + std::shared_ptr sink_reader = compute::MakeGeneratorReader( + output_schema, std::move(sink_gen), exec_context.memory_pool()); + + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + return arrow::Table::FromRecordBatchReader(sink_reader.get()); +} + // Result WriteTemporaryData(const std::string& file_name, const // std::shared_ptr
& table, const std::shared_ptr& filesystem) // { @@ -1872,7 +1888,7 @@ TEST(Substrait, AggregateBadPhase) { ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); } -TEST(Substrait, SerializePlan) { +TEST(Substrait, BasicPlanRoundTripping) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #else @@ -1880,8 +1896,8 @@ TEST(Substrait, SerializePlan) { ExtensionSet ext_set; auto dummy_schema = schema( {field("key", int32()), field("shared", int32()), field("distinct", int32())}); - // creating a dummy dataset using a dummy table + // creating a dummy dataset using a dummy table auto table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], [3, 4, 20] @@ -1934,7 +1950,7 @@ TEST(Substrait, SerializePlan) { const std::string filter_col = "shared"; auto filter = compute::equal(compute::field_ref(filter_col), compute::literal(3)); - arrow::AsyncGenerator > sink_gen; + arrow::AsyncGenerator> sink_gen; auto scan_node_options = dataset::ScanNodeOptions{dataset, scan_options}; auto filter_node_options = compute::FilterNodeOptions{filter}; @@ -1996,188 +2012,143 @@ TEST(Substrait, SerializePlan) { #endif } -// TEST(Substrait, SerializeRelation) { -// #ifdef _WIN32 -// GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; -// #else -// ExtensionSet ext_set; -// auto dummy_schema = schema({field("a", int32()), field("b", int32())}); -// auto table = TableFromJSON(dummy_schema, {R"([ -// [1, 1], -// [3, 4] -// ])", -// R"([ -// [0, 2], -// [1, 3], -// [4, 1], -// [3, 1], -// [1, 2] -// ])", -// R"([ -// [2, 2], -// [5, 3], -// [1, 3] -// ])"}); -// const std::string path = "/testing.parquet"; - -// EXPECT_OK_AND_ASSIGN(auto filesystem, -// fs::internal::MockFileSystem::Make(fs::kNoTime, {})); - -// EXPECT_EQ(WriteParquetData(path, filesystem, table), true); -// // creating a dummy dataset using a dummy table -// auto format = std::make_shared(); - -// std::vector files; -// const std::vector f_paths = {path}; - -// for (const auto& f_path : f_paths) { -// ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); -// files.push_back(std::move(f_file)); -// } - -// ASSERT_OK_AND_ASSIGN(auto ds_factory, -// dataset::FileSystemDatasetFactory::Make( -// filesystem, std::move(files), std::move(format), {})); -// ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); - -// auto options = std::make_shared(); -// options->projection = compute::project({}, {}); -// auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; - -// auto scan_declaration = compute::Declaration({"scan", scan_node_options}); - -// ASSERT_OK_AND_ASSIGN(auto serialized_rel, -// SerializeRelation(scan_declaration, &ext_set)); -// ASSERT_OK_AND_ASSIGN(auto deserialized_decl, -// DeserializeRelation(*serialized_rel, ext_set)); - -// auto& mfs = checked_cast(*filesystem); -// mfs.AllFiles(); - -// EXPECT_EQ(deserialized_decl.factory_name, scan_declaration.factory_name); -// const auto& lhs = -// checked_cast(*deserialized_decl.options); -// const auto& rhs = -// checked_cast(*scan_declaration.options); -// ASSERT_TRUE(CompareScanOptions(lhs, rhs)); -// #endif -// } - -// TEST(Substrait, SerializeRelationEndToEnd) { -// #ifdef _WIN32 -// GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; -// #else -// ExtensionSet ext_set; -// compute::ExecContext exec_context; - -// auto dummy_schema = schema({field("a", int32()), field("b", int32())}); -// auto table = TableFromJSON(dummy_schema, {R"([ -// [1, 1], -// [3, 4] -// ])", -// R"([ -// [0, 2], -// [1, 3], -// [4, 1], -// [3, 1], -// [1, 2] -// ])", -// R"([ -// [2, 2], -// [5, 3], -// [1, 3] -// ])"}); -// const std::string path = "/testing.parquet"; - -// EXPECT_OK_AND_ASSIGN(auto filesystem, -// fs::internal::MockFileSystem::Make(fs::kNoTime, {})); - -// EXPECT_EQ(WriteParquetData(path, filesystem, table), true); - -// auto format = std::make_shared(); - -// std::vector files; -// ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(path)); -// files.push_back(std::move(f_file)); - -// ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( -// filesystem, files, format, {})); -// ASSERT_OK_AND_ASSIGN(auto other_ds_factory, dataset::FileSystemDatasetFactory::Make( -// std::move(filesystem), -// std::move(files), std::move(format), -// {})); - -// ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); -// ASSERT_OK_AND_ASSIGN(auto other_dataset, other_ds_factory->Finish(dummy_schema)); - -// auto options = std::make_shared(); -// options->projection = compute::project({}, {}); - -// auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; - -// arrow::AsyncGenerator > sink_gen; - -// auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - -// auto scan_declaration = compute::Declaration({"scan", scan_node_options}); -// auto sink_declaration = compute::Declaration({"sink", sink_node_options}); +TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#else + compute::ExecContext exec_context; + ExtensionSet ext_set; + auto dummy_schema = schema( + {field("key", int32()), field("shared", int32()), field("distinct", int32())}); -// auto declarations = -// compute::Declaration::Sequence({scan_declaration, sink_declaration}); + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])", + R"([ + [0, 2, 1], + [1, 3, 2], + [4, 1, 3], + [3, 1, 3], + [1, 2, 5] + ])", + R"([ + [2, 2, 12], + [5, 3, 12], + [1, 3, 12] + ])"}); -// ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); -// ASSERT_OK_AND_ASSIGN(auto decl, declarations.AddToPlan(plan.get())); + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + const std::string file_name = "serde_test.parquet"; -// ASSERT_OK(decl->Validate()); + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait_tempdir")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); -// std::shared_ptr sink_reader = compute::MakeGeneratorReader( -// dummy_schema, std::move(sink_gen), exec_context.memory_pool()); + // Note: there is an additional forward slash introduced by the tempdir + // it must be replaced to properly load into reading files + // TODO: (Review: Jira needs to be reported to handle this properly) + std::string toReplace("/T//"); + size_t pos = file_path_str.find(toReplace); + file_path_str.replace(pos, toReplace.length(), "/T/"); -// ASSERT_OK(plan->Validate()); -// ASSERT_OK(plan->StartProducing()); + ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); -// std::shared_ptr response_table; + std::vector files; + const std::vector f_paths = {file_path_str}; -// ASSERT_OK_AND_ASSIGN(response_table, -// arrow::Table::FromRecordBatchReader(sink_reader.get())); + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + } -// auto other_scan_node_options = dataset::ScanNodeOptions{other_dataset, options}; -// auto other_scan_declaration = compute::Declaration({"scan", -// other_scan_node_options}); + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); -// ASSERT_OK_AND_ASSIGN(auto serialized_rel, -// SerializeRelation(other_scan_declaration, &ext_set)); -// ASSERT_OK_AND_ASSIGN(auto deserialized_decl, -// DeserializeRelation(*serialized_rel, ext_set)); + auto scan_options = std::make_shared(); + scan_options->projection = compute::project({}, {}); + const std::string filter_col = "shared"; + auto filter = compute::equal(compute::field_ref(filter_col), compute::literal(3)); -// // arrow::AsyncGenerator > des_sink_gen; -// // auto des_sink_node_options = compute::SinkNodeOptions{&des_sink_gen}; + arrow::AsyncGenerator> sink_gen; -// // auto des_sink_declaration = compute::Declaration({"sink", des_sink_node_options}); + auto scan_node_options = dataset::ScanNodeOptions{dataset, scan_options}; + auto filter_node_options = compute::FilterNodeOptions{filter}; + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; -// // auto t_decls = -// // compute::Declaration::Sequence({deserialized_decl, des_sink_declaration}); + auto scan_declaration = compute::Declaration({"scan", scan_node_options, "s"}); + auto filter_declaration = compute::Declaration({"filter", filter_node_options, "f"}); + auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); -// // ASSERT_OK_AND_ASSIGN(auto t_plan, compute::ExecPlan::Make()); -// // ASSERT_OK_AND_ASSIGN(auto t_decl, t_decls.AddToPlan(t_plan.get())); + auto declarations = compute::Declaration::Sequence( + {scan_declaration, filter_declaration, sink_declaration}); -// // ASSERT_OK(t_decl->Validate()); + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); -// // std::shared_ptr des_sink_reader = -// // compute::MakeGeneratorReader(dummy_schema, std::move(des_sink_gen), -// // exec_context.memory_pool()); + ASSERT_OK_AND_ASSIGN(auto serialized_plan, + SerializePlan(plan.get(), declarations, &ext_set)); -// // ASSERT_OK(t_plan->Validate()); -// // ASSERT_OK(t_plan->StartProducing()); + ASSERT_OK_AND_ASSIGN(auto expected_tb, GetTableFromPlan(plan, declarations, sink_gen, + exec_context, dummy_schema)); -// // std::shared_ptr des_response_table; + for (auto sp_ext_id_reg : + {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; -// // ASSERT_OK_AND_ASSIGN(des_response_table, -// // arrow::Table::FromRecordBatchReader(des_sink_reader.get())); + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[index], filter_col); + EXPECT_EQ(args[1], compute::literal(3)); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); -// // ASSERT_TRUE(response_table->Equals(*des_response_table, true)); -// #endif -// } + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); + } + arrow::AsyncGenerator> rnd_trp_sink_gen; + auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; + auto rnd_trp_sink_declaration = + compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); + auto rnd_trp_declarations = + compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, + GetTableFromPlan(rnd_trp_plan, rnd_trp_declarations, + rnd_trp_sink_gen, exec_context, dummy_schema)); + EXPECT_TRUE(expected_tb->Equals(*rnd_trp_table)); + } +#endif +} } // namespace engine } // namespace arrow From 6907639f2da110ce1a4d263e731b3757e8d56511 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 2 Aug 2022 20:17:31 +0530 Subject: [PATCH 08/30] fix(refactor): refactor the relation and plan impl --- cpp/src/arrow/compute/exec/exec_plan.h | 10 -- cpp/src/arrow/compute/exec/plan_test.cc | 39 ------ cpp/src/arrow/compute/exec/source_node.cc | 2 +- cpp/src/arrow/dataset/scanner_test.cc | 18 --- .../arrow/engine/substrait/plan_internal.cc | 110 +++++++-------- .../arrow/engine/substrait/plan_internal.h | 8 +- cpp/src/arrow/engine/substrait/registry.h | 3 +- .../engine/substrait/relation_internal.cc | 130 ++++++++---------- .../engine/substrait/relation_internal.h | 13 +- cpp/src/arrow/engine/substrait/serde.cc | 17 +-- cpp/src/arrow/engine/substrait/serde.h | 13 +- cpp/src/arrow/engine/substrait/serde_test.cc | 11 +- 12 files changed, 147 insertions(+), 227 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 7be0feaa26d..263f3634a5a 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -199,14 +199,6 @@ class ARROW_EXPORT ExecNode { /// This node's exec plan ExecPlan* plan() { return plan_; } - /// Set this node's options - /// This is an optional method included to support Acero to Substrait - /// serialization. - void SetOptions(std::shared_ptr options) { options_ = options; } - - /// This node's options - std::shared_ptr options() { return options_; } - /// \brief An optional label, for display and debugging /// /// There is no guarantee that this value is non-empty or unique. @@ -375,8 +367,6 @@ class ARROW_EXPORT ExecNode { Future<> finished_ = Future<>::Make(); util::tracing::Span span_; - - std::shared_ptr options_; }; /// \brief MapNode is an ExecNode type class which process a task like filter/project diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 906f3ffaa4d..e06c41c7489 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -279,26 +279,6 @@ TEST(ExecPlanExecution, TableSourceSink) { } } -TEST(ExecPlanExecution, TableSourceNodeOptions) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; - - auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto table, - TableFromExecBatches(exp_batches.schema, exp_batches.batches)); - auto table_source_options = std::make_shared(table, 3); - - ASSERT_OK_AND_ASSIGN(ExecNode * table_source, MakeExecNode("table_source", plan.get(), - {}, *table_source_options)); - - table_source->SetOptions(table_source_options); - const auto& res_table_options = - static_cast(*table_source->options()); - - EXPECT_EQ(table_source_options->table, res_table_options.table); - EXPECT_EQ(table_source_options->max_batch_size, res_table_options.max_batch_size); -} - TEST(ExecPlanExecution, TableSourceSinkError) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); AsyncGenerator> sink_gen; @@ -1501,24 +1481,5 @@ TEST(ExecPlan, SourceEnforcesBatchLimit) { } } -TEST(ExecPlan, ExecNodeOption) { - auto input = MakeGroupableBatches(); - - auto exec_ctx = arrow::internal::make_unique( - default_memory_pool(), arrow::internal::GetCpuThreadPool()); - - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - - AsyncGenerator> sink_gen; - - auto options = std::make_shared( - input.schema, input.gen(/*parallel*/ true, /*slow=*/false)); - - ASSERT_OK_AND_ASSIGN(auto* source, MakeExecNode("source", plan.get(), {}, *options)); - source->SetOptions(options); - const auto& opts = static_cast(*source->options()); - ASSERT_EQ(options->output_schema, opts.output_schema); -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 783b7d33e8b..a640cf737ef 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -48,7 +48,7 @@ namespace { struct SourceNode : ExecNode { SourceNode(ExecPlan* plan, std::shared_ptr output_schema, AsyncGenerator> generator) - : ExecNode(plan, {}, {}, output_schema, + : ExecNode(plan, {}, {}, std::move(output_schema), /*num_outputs=*/1), generator_(std::move(generator)) {} diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index d9343578b44..804e82b57db 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -1991,23 +1991,5 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) { AssertTablesEqual(*expected, *sorted.table(), /*same_chunk_layout=*/false); } -TEST(ScanNode, NodeOptions) { - TestPlan plan; - - auto basic = MakeBasicDataset(); - - auto options = std::make_shared(); - options->projection = Materialize({}); // set an empty projection - auto scan_node_options = std::make_shared(basic.dataset, options); - ASSERT_OK_AND_ASSIGN(auto* scan, - compute::MakeExecNode("scan", plan.get(), {}, *scan_node_options)); - scan->SetOptions(scan_node_options); - const auto& res_scan_options = static_cast(*scan->options()); - - ASSERT_EQ(scan_node_options->dataset->schema(), res_scan_options.dataset->schema()); - ASSERT_EQ(scan_node_options->scan_options->projection, - res_scan_options.scan_options->projection); -} - } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index fdce5eeba2b..51c253df2ac 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -19,7 +19,7 @@ #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" -#include "arrow/engine/substrait/registry.h" +#include "arrow/engine/substrait/relation_internal.h" #include "arrow/result.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" @@ -138,63 +138,63 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, registry); } -Status SetRelation(const std::unique_ptr& plan, - const std::unique_ptr& partial_plan, - const std::string& factory_name) { - if (factory_name == "scan" && partial_plan->has_read()) { - plan->set_allocated_read(partial_plan->release_read()); - } else if (factory_name == "filter" && partial_plan->has_filter()) { - plan->set_allocated_filter(partial_plan->release_filter()); - } else { - return Status::NotImplemented("Substrait converter ", factory_name, - " not supported."); - } - return Status::OK(); -} - -Result> ExtractSchemaToBind(const compute::Declaration& declr) { - std::shared_ptr bind_schema; - if (declr.factory_name == "scan") { - const auto& opts = checked_cast(*(declr.options)); - bind_schema = opts.dataset->schema(); - } else if (declr.factory_name == "filter") { - auto input_declr = util::get(declr.inputs[0]); - ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); - } else if (declr.factory_name == "hashjoin") { - } else if (declr.factory_name == "sink") { - return bind_schema; - } else { - return Status::Invalid("Schema extraction failed, unsupported factory ", - declr.factory_name); - } - return bind_schema; -} - -Status SerializeRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, - std::unique_ptr& rel) { - std::vector inputs = declaration.inputs; - for (auto& input : inputs) { - auto input_decl = util::get(input); - RETURN_NOT_OK(SerializeRelations(input_decl, ext_set, rel)); - } - const auto& factory_name = declaration.factory_name; - ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); - SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); - if (factory_name != "sink") { - ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); - ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set)); - RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); - } - return Status::OK(); -} - -Result> ToProto(compute::ExecPlan* plan, - const compute::Declaration& declr, - ExtensionSet* ext_set) { +// Status SetRelation(const std::unique_ptr& plan, +// const std::unique_ptr& partial_plan, +// const std::string& factory_name) { +// if (factory_name == "scan" && partial_plan->has_read()) { +// plan->set_allocated_read(partial_plan->release_read()); +// } else if (factory_name == "filter" && partial_plan->has_filter()) { +// plan->set_allocated_filter(partial_plan->release_filter()); +// } else { +// return Status::NotImplemented("Substrait converter ", factory_name, +// " not supported."); +// } +// return Status::OK(); +// } + +// Result> ExtractSchemaToBind(const compute::Declaration& declr) { +// std::shared_ptr bind_schema; +// if (declr.factory_name == "scan") { +// const auto& opts = checked_cast(*(declr.options)); +// bind_schema = opts.dataset->schema(); +// } else if (declr.factory_name == "filter") { +// auto input_declr = util::get(declr.inputs[0]); +// ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); +// } else if (declr.factory_name == "hashjoin") { +// } else if (declr.factory_name == "sink") { +// return bind_schema; +// } else { +// return Status::Invalid("Schema extraction failed, unsupported factory ", +// declr.factory_name); +// } +// return bind_schema; +// } + +// Status SerializeRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, +// std::unique_ptr& rel, const ConversionOptions& conversion_options) { +// std::vector inputs = declaration.inputs; +// for (auto& input : inputs) { +// auto input_decl = util::get(input); +// RETURN_NOT_OK(SerializeRelations(input_decl, ext_set, rel, conversion_options)); +// } +// const auto& factory_name = declaration.factory_name; +// ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); +// SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); +// if (factory_name != "sink") { +// ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); +// ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set, conversion_options)); +// RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); +// } +// return Status::OK(); +// } + +Result> PlanToProto(const compute::Declaration& declr, + ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { auto subs_plan = internal::make_unique(); auto plan_rel = internal::make_unique(); auto rel = internal::make_unique(); - RETURN_NOT_OK(SerializeRelations(declr, ext_set, rel)); + RETURN_NOT_OK(CombineRelations(declr, ext_set, rel, conversion_options)); plan_rel->set_allocated_rel(rel.release()); subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 81aaabad029..bf47b4cb17b 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -22,6 +22,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -53,11 +54,8 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); -ARROW_ENGINE_EXPORT Result> ToProto( - compute::ExecPlan* plan, const compute::Declaration& declr, ExtensionSet* ext_set); - -// ARROW_ENGINE_EXPORT Result> ToProto( -// const compute::Declaration& declaration, ExtensionSet* ext_set); +ARROW_ENGINE_EXPORT Result> PlanToProto(const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h index 94f46e1dc7c..96bc70bba87 100644 --- a/cpp/src/arrow/engine/substrait/registry.h +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -31,6 +31,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/engine/substrait/serde.h" #include "arrow/engine/substrait/visibility.h" @@ -46,7 +47,7 @@ class ARROW_EXPORT SubstraitConversionRegistry { public: virtual ~SubstraitConversionRegistry() = default; using SubstraitConverter = std::function>( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet*)>; + const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, const ConversionOptions&)>; virtual Result GetConverter(const std::string& factory_name) = 0; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 055ee9f746a..8379e400754 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -25,6 +25,7 @@ #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/registry.h" #include "arrow/engine/substrait/type_internal.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" @@ -427,95 +428,78 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& rel.DebugString()); } -namespace { - -Result> MakeReadRelation( - const compute::Declaration& declaration, ExtensionSet* ext_set) { - auto read_rel = make_unique(); - const auto& scan_node_options = - checked_cast(*declaration.options); - - auto dataset = - dynamic_cast(scan_node_options.dataset.get()); - if (dataset == nullptr) { - return Status::Invalid("Can only convert file system datasets to a Substrait plan."); - } - // set schema - ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set)); - read_rel->set_allocated_base_schema(named_struct.release()); - - // set local files - auto read_rel_lfs = make_unique(); - for (const auto& file : dataset->files()) { - auto read_rel_lfs_ffs = make_unique(); - read_rel_lfs_ffs->set_uri_path("file://" + file); +Result> ToProto(const compute::Declaration& declr, + ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto rel = make_unique(); + RETURN_NOT_OK(CombineRelations(declr, ext_set, rel, conversion_options)); + return std::move(rel); +} - // set file format - // arrow and feather are temporarily handled via the Parquet format until - // upgraded to the latest Substrait version. - auto format_type_name = dataset->format()->type_name(); - if (format_type_name == "parquet") { - auto parquet_fmt = - make_unique(); - read_rel_lfs_ffs->set_allocated_parquet(parquet_fmt.release()); - } else if (format_type_name == "arrow") { - auto arrow_fmt = - make_unique(); - read_rel_lfs_ffs->set_allocated_arrow(arrow_fmt.release()); - } else if (format_type_name == "orc") { - auto orc_fmt = - make_unique(); - read_rel_lfs_ffs->set_allocated_orc(orc_fmt.release()); - } else { - return Status::Invalid("Unsupported file type : ", format_type_name); - } - read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); +Status SetRelation(const std::unique_ptr& plan, + const std::unique_ptr& partial_plan, + const std::string& factory_name) { + if (factory_name == "scan" && partial_plan->has_read()) { + plan->set_allocated_read(partial_plan->release_read()); + } else if (factory_name == "filter" && partial_plan->has_filter()) { + plan->set_allocated_filter(partial_plan->release_filter()); + } else { + return Status::NotImplemented("Substrait converter ", factory_name, + " not supported."); } - *read_rel->mutable_local_files() = *read_rel_lfs.get(); - - return read_rel; + return Status::OK(); } -Result> MakeRelation( - const compute::Declaration& declaration, ExtensionSet* ext_set) { - const std::string& rel_name = declaration.factory_name; - auto rel = make_unique(); - if (rel_name == "scan") { - rel->set_allocated_read(MakeReadRelation(declaration, ext_set)->release()); - } else if (rel_name == "filter") { - return Status::NotImplemented("Filter operator not supported."); - } else if (rel_name == "project") { - return Status::NotImplemented("Project operator not supported."); - } else if (rel_name == "hashjoin") { - return Status::NotImplemented("Join operator not supported."); - } else if (rel_name == "aggregate") { - return Status::NotImplemented("Aggregate operator not supported."); +Result> ExtractSchemaToBind(const compute::Declaration& declr) { + std::shared_ptr bind_schema; + if (declr.factory_name == "scan") { + const auto& opts = checked_cast(*(declr.options)); + bind_schema = opts.dataset->schema(); + } else if (declr.factory_name == "filter") { + auto input_declr = util::get(declr.inputs[0]); + ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); + } else if (declr.factory_name == "hashjoin") { + } else if (declr.factory_name == "sink") { + return bind_schema; } else { - return Status::Invalid("Unsupported exec node factory name :", rel_name); + return Status::Invalid("Schema extraction failed, unsupported factory ", + declr.factory_name); } - return rel; + return bind_schema; } -} // namespace - -Result> ToProto(const compute::Declaration& declaration, - ExtensionSet* ext_set) { - return MakeRelation(declaration, ext_set); +Status CombineRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, + std::unique_ptr& rel, const ConversionOptions& conversion_options) { + std::vector inputs = declaration.inputs; + for (auto& input : inputs) { + auto input_decl = util::get(input); + RETURN_NOT_OK(CombineRelations(input_decl, ext_set, rel, conversion_options)); + } + const auto& factory_name = declaration.factory_name; + ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); + SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); + if (factory_name != "sink") { + ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); + ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set, conversion_options)); + RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); + } + return Status::OK(); } Result> GetRelationFromDeclaration( - const compute::Declaration& declaration, ExtensionSet* ext_set) { + const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { auto declr_input = declaration.inputs[0]; // TODO: figure out a better way if (util::get_if(&declr_input)) { return Status::NotImplemented("Only support Plans written in Declaration format."); } - return ToProto(util::get(declr_input), ext_set); + auto declr = util::get(declr_input); + return ToProto(declr, ext_set, conversion_options); } Result> ScanRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, - ExtensionSet* ext_set) { + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { auto rel = make_unique(); auto read_rel = make_unique(); const auto& scan_node_options = @@ -526,7 +510,7 @@ Result> ScanRelationConverter( return Status::Invalid("Can only convert file system datasets to a Substrait plan."); } // set schema - ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set, conversion_options)); read_rel->set_allocated_base_schema(named_struct.release()); // set local files @@ -564,7 +548,7 @@ Result> ScanRelationConverter( Result> FilterRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, - ExtensionSet* ext_set) { + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { auto rel = make_unique(); auto filter_rel = make_unique(); const auto& filter_node_options = @@ -580,11 +564,11 @@ Result> FilterRelationConverter( return Status::Invalid("Filter node doesn't have an input."); } - auto input_rel = GetRelationFromDeclaration(declaration, ext_set); + auto input_rel = GetRelationFromDeclaration(declaration, ext_set, conversion_options); filter_rel->set_allocated_input(input_rel->release()); - ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set)); + ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set, conversion_options)); *filter_rel->mutable_condition() = *subs_expr.get(); rel->set_allocated_filter(filter_rel.release()); return rel; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index d1fec63abcc..588653997cd 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -44,15 +44,18 @@ ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); -ARROW_ENGINE_EXPORT -Result> ToProto(const compute::Declaration&, - ExtensionSet*); +ARROW_ENGINE_EXPORT Status CombineRelations(const compute::Declaration&, ExtensionSet*, + std::unique_ptr&, const ConversionOptions&); + +ARROW_ENGINE_EXPORT Result> ToProto(const compute::Declaration&, + ExtensionSet*, + const ConversionOptions&); ARROW_ENGINE_EXPORT Result> ScanRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set); + const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, const ConversionOptions& conversion_options); ARROW_ENGINE_EXPORT Result> FilterRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set); + const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, const ConversionOptions& conversion_options); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 58d307e1201..74d706a4d3c 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,23 +52,24 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } -Result> SerializePlan(compute::ExecPlan* plan, - const compute::Declaration& declr, - ExtensionSet* ext_set) { - ARROW_ASSIGN_OR_RAISE(auto subs_plan, ToProto(plan, declr, ext_set)); +Result> SerializePlan(const compute::Declaration& declr, + ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto subs_plan, PlanToProto(declr, ext_set, conversion_options)); std::string serialized = subs_plan->SerializeAsString(); return Buffer::FromString(std::move(serialized)); } Result> SerializeRelation(const compute::Declaration& declaration, - ExtensionSet* ext_set) { - ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set)); + ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set, conversion_options)); std::string serialized = relation->SerializeAsString(); return Buffer::FromString(std::move(serialized)); } -Result DeserializeRelation(const Buffer& buf, - const ExtensionSet& ext_set) { +Result DeserializeRelation(const Buffer& buf, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); ARROW_ASSIGN_OR_RAISE(auto decl_info, FromProto(rel, ext_set, conversion_options)); return std::move(decl_info.declaration); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index d573d55a7b1..01e7e26ba46 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -37,9 +37,9 @@ namespace arrow { namespace engine { ARROW_ENGINE_EXPORT -Result> SerializePlan(compute::ExecPlan* plan, - const compute::Declaration& declr, - ExtensionSet* ext_set); +Result> SerializePlan(const compute::Declaration& declr, + ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. @@ -129,10 +129,6 @@ ARROW_ENGINE_EXPORT Result> DeserializePlan( const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR, const ConversionOptions& conversion_options = {}); -ARROW_ENGINE_EXPORT Result DeserializePlan( - const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out = NULLPTR); - /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type @@ -219,7 +215,8 @@ Result> SerializeExpression( /// \return a buffer containing the protobuf serialization of the corresponding Substrait /// Relation message ARROW_ENGINE_EXPORT Result> SerializeRelation( - const compute::Declaration& declaration, ExtensionSet* ext_set); + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); /// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration /// diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index f6bb351296f..701ff1ecf3a 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/engine/substrait/serde.h" - #include #include #include @@ -24,11 +22,16 @@ #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" +#include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" + #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/serde.h" + #include "arrow/engine/substrait/util.h" + #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/test_util.h" @@ -1966,7 +1969,7 @@ TEST(Substrait, BasicPlanRoundTripping) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); ASSERT_OK_AND_ASSIGN(auto serialized_plan, - SerializePlan(plan.get(), declarations, &ext_set)); + SerializePlan(declarations, &ext_set)); for (auto sp_ext_id_reg : {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { @@ -2090,7 +2093,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); ASSERT_OK_AND_ASSIGN(auto serialized_plan, - SerializePlan(plan.get(), declarations, &ext_set)); + SerializePlan(declarations, &ext_set)); ASSERT_OK_AND_ASSIGN(auto expected_tb, GetTableFromPlan(plan, declarations, sink_gen, exec_context, dummy_schema)); From 5b5cc8307a2ec27f76e8d1dd0b6a09f238bc33a2 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 2 Aug 2022 20:23:33 +0530 Subject: [PATCH 09/30] fix(format): reformat code --- .../arrow/engine/substrait/plan_internal.cc | 60 ++----------------- .../arrow/engine/substrait/plan_internal.h | 3 +- cpp/src/arrow/engine/substrait/registry.h | 3 +- .../engine/substrait/relation_internal.cc | 29 +++++---- .../engine/substrait/relation_internal.h | 17 +++--- cpp/src/arrow/engine/substrait/serde.cc | 17 +++--- cpp/src/arrow/engine/substrait/serde.h | 6 +- cpp/src/arrow/engine/substrait/serde_test.cc | 6 +- 8 files changed, 50 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 51c253df2ac..6caf0b69193 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -21,7 +21,6 @@ #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/result.h" -#include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" @@ -36,7 +35,6 @@ using internal::checked_cast; namespace engine { namespace internal { -using ::arrow::internal::hash_combine; using ::arrow::internal::make_unique; } // namespace internal @@ -138,63 +136,13 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, registry); } -// Status SetRelation(const std::unique_ptr& plan, -// const std::unique_ptr& partial_plan, -// const std::string& factory_name) { -// if (factory_name == "scan" && partial_plan->has_read()) { -// plan->set_allocated_read(partial_plan->release_read()); -// } else if (factory_name == "filter" && partial_plan->has_filter()) { -// plan->set_allocated_filter(partial_plan->release_filter()); -// } else { -// return Status::NotImplemented("Substrait converter ", factory_name, -// " not supported."); -// } -// return Status::OK(); -// } - -// Result> ExtractSchemaToBind(const compute::Declaration& declr) { -// std::shared_ptr bind_schema; -// if (declr.factory_name == "scan") { -// const auto& opts = checked_cast(*(declr.options)); -// bind_schema = opts.dataset->schema(); -// } else if (declr.factory_name == "filter") { -// auto input_declr = util::get(declr.inputs[0]); -// ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); -// } else if (declr.factory_name == "hashjoin") { -// } else if (declr.factory_name == "sink") { -// return bind_schema; -// } else { -// return Status::Invalid("Schema extraction failed, unsupported factory ", -// declr.factory_name); -// } -// return bind_schema; -// } - -// Status SerializeRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, -// std::unique_ptr& rel, const ConversionOptions& conversion_options) { -// std::vector inputs = declaration.inputs; -// for (auto& input : inputs) { -// auto input_decl = util::get(input); -// RETURN_NOT_OK(SerializeRelations(input_decl, ext_set, rel, conversion_options)); -// } -// const auto& factory_name = declaration.factory_name; -// ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); -// SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); -// if (factory_name != "sink") { -// ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); -// ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set, conversion_options)); -// RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); -// } -// return Status::OK(); -// } - -Result> PlanToProto(const compute::Declaration& declr, - ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { +Result> PlanToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { auto subs_plan = internal::make_unique(); auto plan_rel = internal::make_unique(); auto rel = internal::make_unique(); - RETURN_NOT_OK(CombineRelations(declr, ext_set, rel, conversion_options)); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, rel, conversion_options)); plan_rel->set_allocated_rel(rel.release()); subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index bf47b4cb17b..04ce7298d8d 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -54,7 +54,8 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); -ARROW_ENGINE_EXPORT Result> PlanToProto(const compute::Declaration& declr, ExtensionSet* ext_set, +ARROW_ENGINE_EXPORT Result> PlanToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h index 96bc70bba87..10ff1b45d84 100644 --- a/cpp/src/arrow/engine/substrait/registry.h +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -47,7 +47,8 @@ class ARROW_EXPORT SubstraitConversionRegistry { public: virtual ~SubstraitConversionRegistry() = default; using SubstraitConverter = std::function>( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, const ConversionOptions&)>; + const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, + const ConversionOptions&)>; virtual Result GetConverter(const std::string& factory_name) = 0; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 8379e400754..b6a28298da2 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -428,11 +428,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& rel.DebugString()); } -Result> ToProto(const compute::Declaration& declr, - ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { +Result> ToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { auto rel = make_unique(); - RETURN_NOT_OK(CombineRelations(declr, ext_set, rel, conversion_options)); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, rel, conversion_options)); return std::move(rel); } @@ -468,26 +468,31 @@ Result> ExtractSchemaToBind(const compute::Declaration& return bind_schema; } -Status CombineRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, - std::unique_ptr& rel, const ConversionOptions& conversion_options) { +Status SerializeAndCombineRelations(const compute::Declaration& declaration, + ExtensionSet* ext_set, + std::unique_ptr& rel, + const ConversionOptions& conversion_options) { std::vector inputs = declaration.inputs; for (auto& input : inputs) { auto input_decl = util::get(input); - RETURN_NOT_OK(CombineRelations(input_decl, ext_set, rel, conversion_options)); + RETURN_NOT_OK( + SerializeAndCombineRelations(input_decl, ext_set, rel, conversion_options)); } const auto& factory_name = declaration.factory_name; ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); if (factory_name != "sink") { ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); - ARROW_ASSIGN_OR_RAISE(auto factory_rel, factory(schema, declaration, ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE(auto factory_rel, + factory(schema, declaration, ext_set, conversion_options)); RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); } return Status::OK(); } Result> GetRelationFromDeclaration( - const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { auto declr_input = declaration.inputs[0]; // TODO: figure out a better way if (util::get_if(&declr_input)) { @@ -510,7 +515,8 @@ Result> ScanRelationConverter( return Status::Invalid("Can only convert file system datasets to a Substrait plan."); } // set schema - ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE(auto named_struct, + ToProto(*dataset->schema(), ext_set, conversion_options)); read_rel->set_allocated_base_schema(named_struct.release()); // set local files @@ -568,7 +574,8 @@ Result> FilterRelationConverter( filter_rel->set_allocated_input(input_rel->release()); - ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE(auto subs_expr, + ToProto(bound_expression, ext_set, conversion_options)); *filter_rel->mutable_condition() = *subs_expr.get(); rel->set_allocated_filter(filter_rel.release()); return rel; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 588653997cd..447831dfa47 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -44,18 +44,21 @@ ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); -ARROW_ENGINE_EXPORT Status CombineRelations(const compute::Declaration&, ExtensionSet*, - std::unique_ptr&, const ConversionOptions&); +ARROW_ENGINE_EXPORT Status SerializeAndCombineRelations(const compute::Declaration&, + ExtensionSet*, + std::unique_ptr&, + const ConversionOptions&); -ARROW_ENGINE_EXPORT Result> ToProto(const compute::Declaration&, - ExtensionSet*, - const ConversionOptions&); +ARROW_ENGINE_EXPORT Result> ToProto( + const compute::Declaration&, ExtensionSet*, const ConversionOptions&); ARROW_ENGINE_EXPORT Result> ScanRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, const ConversionOptions& conversion_options); + const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, + const ConversionOptions& conversion_options); ARROW_ENGINE_EXPORT Result> FilterRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, const ConversionOptions& conversion_options); + const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, + const ConversionOptions& conversion_options); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 74d706a4d3c..a311a4763ba 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,24 +52,25 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } -Result> SerializePlan(const compute::Declaration& declr, - ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { +Result> SerializePlan( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto subs_plan, PlanToProto(declr, ext_set, conversion_options)); std::string serialized = subs_plan->SerializeAsString(); return Buffer::FromString(std::move(serialized)); } -Result> SerializeRelation(const compute::Declaration& declaration, - ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { +Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set, conversion_options)); std::string serialized = relation->SerializeAsString(); return Buffer::FromString(std::move(serialized)); } -Result DeserializeRelation(const Buffer& buf, const ExtensionSet& ext_set, - const ConversionOptions& conversion_options) { +Result DeserializeRelation( + const Buffer& buf, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); ARROW_ASSIGN_OR_RAISE(auto decl_info, FromProto(rel, ext_set, conversion_options)); return std::move(decl_info.declaration); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 01e7e26ba46..7209cea2ba2 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -37,9 +37,9 @@ namespace arrow { namespace engine { ARROW_ENGINE_EXPORT -Result> SerializePlan(const compute::Declaration& declr, - ExtensionSet* ext_set, - const ConversionOptions& conversion_options = {}); +Result> SerializePlan( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 701ff1ecf3a..f35a56b4e72 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1968,8 +1968,7 @@ TEST(Substrait, BasicPlanRoundTripping) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN(auto serialized_plan, - SerializePlan(declarations, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); for (auto sp_ext_id_reg : {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { @@ -2092,8 +2091,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN(auto serialized_plan, - SerializePlan(declarations, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); ASSERT_OK_AND_ASSIGN(auto expected_tb, GetTableFromPlan(plan, declarations, sink_gen, exec_context, dummy_schema)); From a4510439e10257e135fcd5a52fb67bfc181b427c Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 3 Aug 2022 17:51:40 +0530 Subject: [PATCH 10/30] fix(export): updating export declaration for registry --- cpp/src/arrow/engine/substrait/registry.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h index 10ff1b45d84..fb887630c12 100644 --- a/cpp/src/arrow/engine/substrait/registry.h +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -43,7 +43,7 @@ namespace arrow { namespace engine { -class ARROW_EXPORT SubstraitConversionRegistry { +class ARROW_ENGINE_EXPORT SubstraitConversionRegistry { public: virtual ~SubstraitConversionRegistry() = default; using SubstraitConverter = std::function>( @@ -56,7 +56,7 @@ class ARROW_EXPORT SubstraitConversionRegistry { SubstraitConverter converter) = 0; }; -ARROW_EXPORT SubstraitConversionRegistry* default_substrait_conversion_registry(); +ARROW_ENGINE_EXPORT SubstraitConversionRegistry* default_substrait_conversion_registry(); } // namespace engine } // namespace arrow From 05c2d7a152ef1543f0c6483452fe8adad3db654a Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 3 Aug 2022 18:48:09 +0530 Subject: [PATCH 11/30] fix(format): formatting the code and fixing unresolved minor issues --- .../arrow/engine/substrait/plan_internal.h | 9 ++++++- cpp/src/arrow/engine/substrait/registry.h | 24 +++++++++++++++++++ .../engine/substrait/relation_internal.cc | 15 ++++++------ 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 04ce7298d8d..e8a07ad666f 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -20,7 +20,6 @@ #pragma once #include "arrow/compute/exec/exec_plan.h" - #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/visibility.h" @@ -54,6 +53,14 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); +/// \brief Serializes Declaration and produces a substrait::Plan. +/// +/// Note that, this is a part of roundtripping test API and not +/// designed to use in production +/// \param[in] declr the sequence of declarations +/// \param[in, out] ext_set the extension set to be updated +/// \param[in] conversion_options the conversion options useful for the serialization +/// \return serialized Acero plan ARROW_ENGINE_EXPORT Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h index fb887630c12..82c3bc2a7f8 100644 --- a/cpp/src/arrow/engine/substrait/registry.h +++ b/cpp/src/arrow/engine/substrait/registry.h @@ -43,19 +43,43 @@ namespace arrow { namespace engine { +/// \brief Acero-Substrait integration contains converters which enables +/// converting Acero ExecPlan related entities to the corresponding Substrait +/// entities. +/// +/// Note that the current registry definition only holds converters to convert +/// an Acero plan to Substrait plan. class ARROW_ENGINE_EXPORT SubstraitConversionRegistry { public: virtual ~SubstraitConversionRegistry() = default; + + /// \brief Alias for Acero-to-Substrait converter using SubstraitConverter = std::function>( const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, const ConversionOptions&)>; + /// \brief Retrieve a SubstraitConverter from the registry by factory name + /// + /// \param[in] factory_name name of the converter (aligned with Acero ExecNode kind + /// name) \return the matching SubstraitConverter virtual Result GetConverter(const std::string& factory_name) = 0; + /// \brief Register a converter by factory + /// + /// \param[in] factory_name name of the converter + /// \param[in] converter the std::function encapsulating the converter logic + /// \return Status of the registration virtual Status RegisterConverter(std::string factory_name, SubstraitConverter converter) = 0; }; +/// \brief Retrive the default Acero-to-Substrait conversion registry +/// The default registry contains the converters corresponding to mapping +/// the core ExecNodes in Acero. +/// +/// The default registry can be represented as a parent registry if a non-Acero +/// converters are required to be used with it. It must be separately implemented +/// by using the default input as the parent. ARROW_ENGINE_EXPORT SubstraitConversionRegistry* default_substrait_conversion_registry(); } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index b6a28298da2..9785390c798 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -458,8 +458,8 @@ Result> ExtractSchemaToBind(const compute::Declaration& } else if (declr.factory_name == "filter") { auto input_declr = util::get(declr.inputs[0]); ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); - } else if (declr.factory_name == "hashjoin") { } else if (declr.factory_name == "sink") { + // Note that the sink has no output_schema return bind_schema; } else { return Status::Invalid("Schema extraction failed, unsupported factory ", @@ -494,12 +494,13 @@ Result> GetRelationFromDeclaration( const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { auto declr_input = declaration.inputs[0]; - // TODO: figure out a better way + // Note that the input is expected in declaration. + // ExecNode inputs are not accepted if (util::get_if(&declr_input)) { return Status::NotImplemented("Only support Plans written in Declaration format."); } - auto declr = util::get(declr_input); - return ToProto(declr, ext_set, conversion_options); + return ToProto(util::get(declr_input), ext_set, + conversion_options); } Result> ScanRelationConverter( @@ -524,7 +525,6 @@ Result> ScanRelationConverter( for (const auto& file : dataset->files()) { auto read_rel_lfs_ffs = make_unique(); read_rel_lfs_ffs->set_uri_path("file://" + file); - // set file format // arrow and feather are temporarily handled via the Parquet format until // upgraded to the latest Substrait version. @@ -546,8 +546,7 @@ Result> ScanRelationConverter( } read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); } - // TODO(Before PR Merge) : evaluate better hand-off of pointers - *read_rel->mutable_local_files() = *read_rel_lfs.get(); + read_rel->set_allocated_local_files(read_rel_lfs.release()); rel->set_allocated_read(read_rel.release()); return std::move(rel); } @@ -576,7 +575,7 @@ Result> FilterRelationConverter( ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set, conversion_options)); - *filter_rel->mutable_condition() = *subs_expr.get(); + filter_rel->set_allocated_condition(subs_expr.release()); rel->set_allocated_filter(filter_rel.release()); return rel; } From 3ce1c185cf33fc1c6375aeda0cea83a06a884ae8 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 3 Aug 2022 19:14:29 +0530 Subject: [PATCH 12/30] fix(format): included docstrings and clode clean up --- .../engine/substrait/relation_internal.h | 29 +++++++++++++-- cpp/src/arrow/engine/substrait/serde_test.cc | 37 ------------------- 2 files changed, 25 insertions(+), 41 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 447831dfa47..1cb2f3673ef 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -40,25 +40,46 @@ struct DeclarationInfo { int num_columns; }; +/// \brief A function to extract Acero Declaration from a Substrait Rel object ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); +/// \brief Serializes a Declaration, produce a Substrait Rel and update the global +/// Substrait plan. A Substrait Rel is passed as a the plan and it is updated with +/// corresponding Declaration passed for serialization. +/// +/// Note that this is a rather a helper method useful to fuse a partially serialized +/// plan with another plan. The reason for having a partially serialized plan is to +/// avoid unnecessary complication and enable partial plan serialization without +/// affecting a global plan. Since kept as unique_ptr resources are relased efficiently +/// upon releasing for the global plan. ARROW_ENGINE_EXPORT Status SerializeAndCombineRelations(const compute::Declaration&, ExtensionSet*, std::unique_ptr&, const ConversionOptions&); +/// \brief Serialize a Declaration and produces a Substrait Rel. +/// +/// Note that in order to provide a generic interface for ToProto for +/// declaration it is not specialized for each relation within the Substrait Rel. +/// Rather a serialized relation is set as a member for the Substrait Rel +/// (partial Relation) which is later on extracted to update a Substrait Rel +/// which would be included in the fully serialized Acero Exec Plan. +/// The ExecNode or ExecPlan is not used in this context as Declaration is preferred +/// in the Substrait space rather than internal components of Acero execution engine. ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); +/// \brief Acero to Substrait converter for Acero scan relation. ARROW_ENGINE_EXPORT Result> ScanRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, - const ConversionOptions& conversion_options); + const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, + const ConversionOptions&); +/// \brief Acero to Substrait converter for Acero filter relation. ARROW_ENGINE_EXPORT Result> FilterRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet* ext_set, - const ConversionOptions& conversion_options); + const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, + const ConversionOptions&); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index f35a56b4e72..5f0c337f887 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -84,43 +84,6 @@ Result> GetTableFromPlan( return arrow::Table::FromRecordBatchReader(sink_reader.get()); } -// Result WriteTemporaryData(const std::string& file_name, const -// std::shared_ptr
& table, const std::shared_ptr& filesystem) -// { -// ARROW_ASSIGN_OR_RAISE(auto tempdir, -// arrow::internal::TemporaryDir::Make("substrait_tempdir")); ARROW_ASSIGN_OR_RAISE(auto -// file_path, tempdir->path().Join(file_name)); std::string file_path_str = -// file_path.ToString(); EXPECT_EQ(WriteParquetData(file_path_str, filesystem, table), -// true); return file_path_str; -// } - -bool CompareDataset(std::shared_ptr ds_lhs, - std::shared_ptr ds_rhs) { - const auto& fsd_lhs = checked_cast(*ds_lhs); - const auto& fsd_rhs = checked_cast(*ds_rhs); - const auto& files_lhs = fsd_lhs.files(); - const auto& files_rhs = fsd_rhs.files(); - - if (files_lhs.size() != files_rhs.size()) { - return false; - } - uint64_t fidx = 0; - for (const auto& l_file : files_lhs) { - if (l_file != files_rhs[fidx++]) { - return false; - } - } - bool cmp_file_format = fsd_lhs.format()->Equals(*fsd_rhs.format()); - bool cmp_file_system = fsd_lhs.filesystem()->Equals(fsd_rhs.filesystem()); - return cmp_file_format && cmp_file_system; -} - -bool CompareScanOptions(const dataset::ScanNodeOptions& lhs, - const dataset::ScanNodeOptions& rhs) { - return lhs.require_sequenced_output == rhs.require_sequenced_output && - CompareDataset(lhs.dataset, rhs.dataset); -} - class NullSinkNodeConsumer : public compute::SinkNodeConsumer { public: Status Init(const std::shared_ptr&, compute::BackpressureControl*) override { From 1a179a10ad38b55fbfba5080199b1cc0cb37e87e Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 3 Aug 2022 19:20:43 +0530 Subject: [PATCH 13/30] fix(review): addressing a previous review comment --- cpp/src/arrow/engine/substrait/serde_test.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 5f0c337f887..ccc18a6834c 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -60,11 +60,10 @@ namespace engine { Status WriteParquetData(const std::string& path, const std::shared_ptr file_system, - const std::shared_ptr
input, - const int64_t chunk_size = 3) { + const std::shared_ptr
input) { EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path)); PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(), - buffer_writer, chunk_size)); + buffer_writer, /*chunk_size*/ 1)); return buffer_writer->Close(); } From 630524a1ac920f3541c61dd6f6b1538c8f9b9fff Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 3 Aug 2022 19:25:47 +0530 Subject: [PATCH 14/30] fix(review): addressing review comment --- cpp/src/arrow/engine/substrait/relation_internal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 9785390c798..21d29f05067 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -542,7 +542,7 @@ Result> ScanRelationConverter( make_unique(); read_rel_lfs_ffs->set_allocated_orc(orc_fmt.release()); } else { - return Status::Invalid("Unsupported file type : ", format_type_name); + return Status::NotImplemented("Unsupported file type: ", format_type_name); } read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); } From ce13740b5120015de135298db9e15b13162d2b1c Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 3 Aug 2022 19:42:08 +0530 Subject: [PATCH 15/30] fix(code): missed move op added --- cpp/src/arrow/engine/substrait/relation_internal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 21d29f05067..3e6f3d31720 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -577,7 +577,7 @@ Result> FilterRelationConverter( ToProto(bound_expression, ext_set, conversion_options)); filter_rel->set_allocated_condition(subs_expr.release()); rel->set_allocated_filter(filter_rel.release()); - return rel; + return std::move(rel); } } // namespace engine From e6abfc9a45f7c6618f29623b7c4065ff220de6ab Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 4 Aug 2022 12:48:51 +0530 Subject: [PATCH 16/30] fix(path): using ToNative instead of ToString --- cpp/src/arrow/engine/substrait/serde_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index ccc18a6834c..88c85c385c1 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -2010,7 +2010,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait_tempdir")); ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); - std::string file_path_str = file_path.ToString(); + std::string file_path_str = file_path.ToNative(); // Note: there is an additional forward slash introduced by the tempdir // it must be replaced to properly load into reading files From f07de57ef49fb91c6b38d44db4b59a793d540596 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 4 Aug 2022 13:57:09 +0530 Subject: [PATCH 17/30] fix(docs): added conversion_options to docstring --- cpp/src/arrow/engine/substrait/serde.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 7209cea2ba2..a8cb6e20d3a 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -211,9 +211,10 @@ Result> SerializeExpression( /// /// \param[in] declaration the Arrow compute declaration to serialize /// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// \param[in] conversion_options options to control how the conversion is done /// mappings for the components in the used declaration /// \return a buffer containing the protobuf serialization of the corresponding Substrait -/// Relation message +/// relation message ARROW_ENGINE_EXPORT Result> SerializeRelation( const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); From ea878ea29d5d3eaeeee48d577e35925da995ee11 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 22 Aug 2022 14:53:27 +0530 Subject: [PATCH 18/30] fix(rebase): rebasing with Substrait changes --- cpp/src/arrow/engine/CMakeLists.txt | 2 + .../arrow/engine/substrait/extension_set.cc | 19 ++++++ cpp/src/arrow/engine/substrait/serde_test.cc | 62 ++++++++++--------- 3 files changed, 54 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index f3e39a37c53..5153137c3af 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -26,6 +26,8 @@ set(ARROW_SUBSTRAIT_SRCS substrait/plan_internal.cc substrait/relation_internal.cc substrait/registry.cc + substrait/serde.cc + substrait/test_plan_builder.cc substrait/type_internal.cc substrait/util.cc) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 6e8522897ee..0e1f5ebc664 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -698,6 +698,20 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic }; } +ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrait_fn_id) { + return + [substrait_fn_id](const compute::Expression::Call& call) -> Result { + // nullable=true isn't quite correct but we don't know the nullability of + // the inputs + SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), + /*nullable=*/true); + for (std::size_t i = 0; i < call.arguments.size(); i++) { + substrait_call.SetValueArg(static_cast(i), call.arguments[i]); + } + return std::move(substrait_call); + }; +} + ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( const std::string& function_name, uint32_t max_args) { return [function_name, @@ -873,6 +887,11 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { AddArrowToSubstraitCall(std::string(fn_name) + "_checked", EncodeOptionlessOverflowableArithmetic(fn_id))); } + // Comparison operators + for (const auto& fn_name : {"equal", "is_not_distinct_from"}) { + Id fn_id{kSubstraitComparisonFunctionsUri, fn_name}; + DCHECK_OK(AddArrowToSubstraitCall(fn_name, EncodeOptionlessComparison(fn_id))); + } } }; diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 88c85c385c1..7c8d79db555 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1858,7 +1858,6 @@ TEST(Substrait, BasicPlanRoundTripping) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #else compute::ExecContext exec_context; - ExtensionSet ext_set; auto dummy_schema = schema( {field("key", int32()), field("shared", int32()), field("distinct", int32())}); @@ -1912,8 +1911,11 @@ TEST(Substrait, BasicPlanRoundTripping) { auto scan_options = std::make_shared(); scan_options->projection = compute::project({}, {}); - const std::string filter_col = "shared"; - auto filter = compute::equal(compute::field_ref(filter_col), compute::literal(3)); + const std::string filter_col_left = "shared"; + const std::string filter_col_right = "distinct"; + auto comp_left_value = compute::field_ref(filter_col_left); + auto comp_right_value = compute::field_ref(filter_col_right); + auto filter = compute::equal(comp_left_value, comp_right_value); arrow::AsyncGenerator> sink_gen; @@ -1927,15 +1929,14 @@ TEST(Substrait, BasicPlanRoundTripping) { auto declarations = compute::Declaration::Sequence( {scan_declaration, filter_declaration, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); - - for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + for (auto sp_ext_id_reg : {MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + ASSERT_OK_AND_ASSIGN( auto sink_decls, DeserializePlans( @@ -1949,9 +1950,10 @@ TEST(Substrait, BasicPlanRoundTripping) { if (auto* call = roundtripped_expr.call()) { EXPECT_EQ(call->function_name, "equal"); auto args = call->arguments; - auto index = args[0].field_ref()->field_path()->indices()[0]; - EXPECT_EQ(dummy_schema->field_names()[index], filter_col); - EXPECT_EQ(args[1], compute::literal(3)); + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); } // scan declaration auto roundtripped_scan = roundtripped_filter->inputs[0].get(); @@ -1988,19 +1990,19 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { // creating a dummy dataset using a dummy table auto table = TableFromJSON(dummy_schema, {R"([ [1, 1, 10], - [3, 4, 20] + [3, 4, 4] ])", R"([ [0, 2, 1], [1, 3, 2], - [4, 1, 3], + [4, 1, 1], [3, 1, 3], - [1, 2, 5] + [1, 2, 2] ])", R"([ [2, 2, 12], [5, 3, 12], - [1, 3, 12] + [1, 3, 3] ])"}); auto format = std::make_shared(); @@ -2035,8 +2037,11 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { auto scan_options = std::make_shared(); scan_options->projection = compute::project({}, {}); - const std::string filter_col = "shared"; - auto filter = compute::equal(compute::field_ref(filter_col), compute::literal(3)); + const std::string filter_col_left = "shared"; + const std::string filter_col_right = "distinct"; + auto comp_left_value = compute::field_ref(filter_col_left); + auto comp_right_value = compute::field_ref(filter_col_right); + auto filter = compute::equal(comp_left_value, comp_right_value); arrow::AsyncGenerator> sink_gen; @@ -2050,18 +2055,16 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { auto declarations = compute::Declaration::Sequence( {scan_declaration, filter_declaration, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(plan, declarations, sink_gen, + exec_context, dummy_schema)); - ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); - - ASSERT_OK_AND_ASSIGN(auto expected_tb, GetTableFromPlan(plan, declarations, sink_gen, - exec_context, dummy_schema)); - - for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + for (auto sp_ext_id_reg : {MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + ASSERT_OK_AND_ASSIGN( auto sink_decls, DeserializePlans( @@ -2075,9 +2078,10 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { if (auto* call = roundtripped_expr.call()) { EXPECT_EQ(call->function_name, "equal"); auto args = call->arguments; - auto index = args[0].field_ref()->field_path()->indices()[0]; - EXPECT_EQ(dummy_schema->field_names()[index], filter_col); - EXPECT_EQ(args[1], compute::literal(3)); + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); } // scan declaration auto roundtripped_scan = roundtripped_filter->inputs[0].get(); @@ -2108,7 +2112,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, GetTableFromPlan(rnd_trp_plan, rnd_trp_declarations, rnd_trp_sink_gen, exec_context, dummy_schema)); - EXPECT_TRUE(expected_tb->Equals(*rnd_trp_table)); + EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } #endif } From 1daecba2f98b8f18f935376defb7d50b7afb8e64 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 22 Aug 2022 15:08:55 +0530 Subject: [PATCH 19/30] fix(address_review): refactor --- cpp/src/arrow/engine/substrait/serde_test.cc | 45 +++++++------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 7c8d79db555..a615d5a9aff 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -68,9 +68,10 @@ Status WriteParquetData(const std::string& path, } Result> GetTableFromPlan( - std::shared_ptr& plan, compute::Declaration& declarations, + compute::Declaration& declarations, arrow::AsyncGenerator>& sink_gen, compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); RETURN_NOT_OK(decl->Validate()); @@ -1856,7 +1857,7 @@ TEST(Substrait, AggregateBadPhase) { TEST(Substrait, BasicPlanRoundTripping) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; -#else +#endif compute::ExecContext exec_context; auto dummy_schema = schema( {field("key", int32()), field("shared", int32()), field("distinct", int32())}); @@ -1919,17 +1920,11 @@ TEST(Substrait, BasicPlanRoundTripping) { arrow::AsyncGenerator> sink_gen; - auto scan_node_options = dataset::ScanNodeOptions{dataset, scan_options}; - auto filter_node_options = compute::FilterNodeOptions{filter}; - auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - - auto scan_declaration = compute::Declaration({"scan", scan_node_options, "s"}); - auto filter_declaration = compute::Declaration({"filter", filter_node_options, "f"}); - auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); - auto declarations = compute::Declaration::Sequence( - {scan_declaration, filter_declaration, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); + {compute::Declaration( + {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), + compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); for (auto sp_ext_id_reg : {MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); @@ -1975,13 +1970,12 @@ TEST(Substrait, BasicPlanRoundTripping) { EXPECT_TRUE(l_frag->Equals(*r_frag)); } } -#endif } TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; -#else +#endif compute::ExecContext exec_context; ExtensionSet ext_set; auto dummy_schema = schema( @@ -2045,18 +2039,13 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { arrow::AsyncGenerator> sink_gen; - auto scan_node_options = dataset::ScanNodeOptions{dataset, scan_options}; - auto filter_node_options = compute::FilterNodeOptions{filter}; - auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; - - auto scan_declaration = compute::Declaration({"scan", scan_node_options, "s"}); - auto filter_declaration = compute::Declaration({"filter", filter_node_options, "f"}); - auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); - auto declarations = compute::Declaration::Sequence( - {scan_declaration, filter_declaration, sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); - ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(plan, declarations, sink_gen, + {compute::Declaration( + {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), + compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + + ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(declarations, sink_gen, exec_context, dummy_schema)); for (auto sp_ext_id_reg : {MakeExtensionIdRegistry()}) { @@ -2108,13 +2097,11 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); auto rnd_trp_declarations = compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto rnd_trp_plan, compute::ExecPlan::Make(&exec_context)); ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, - GetTableFromPlan(rnd_trp_plan, rnd_trp_declarations, - rnd_trp_sink_gen, exec_context, dummy_schema)); + GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, + exec_context, dummy_schema)); EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } -#endif } } // namespace engine From ea8c5576ae299441e8e83c6e4e289e293e20a565 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 24 Aug 2022 09:35:56 +0530 Subject: [PATCH 20/30] fix(registry): cleaning up registry --- cpp/src/arrow/engine/CMakeLists.txt | 1 - cpp/src/arrow/engine/substrait/registry.cc | 71 --------------- cpp/src/arrow/engine/substrait/registry.h | 86 ------------------- .../engine/substrait/relation_internal.cc | 29 +++++-- 4 files changed, 23 insertions(+), 164 deletions(-) delete mode 100644 cpp/src/arrow/engine/substrait/registry.cc delete mode 100644 cpp/src/arrow/engine/substrait/registry.h diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index 5153137c3af..a8d5be90af8 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -25,7 +25,6 @@ set(ARROW_SUBSTRAIT_SRCS substrait/extension_types.cc substrait/plan_internal.cc substrait/relation_internal.cc - substrait/registry.cc substrait/serde.cc substrait/test_plan_builder.cc substrait/type_internal.cc diff --git a/cpp/src/arrow/engine/substrait/registry.cc b/cpp/src/arrow/engine/substrait/registry.cc deleted file mode 100644 index 237a689b5ac..00000000000 --- a/cpp/src/arrow/engine/substrait/registry.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle - -#include "arrow/engine/substrait/registry.h" -#include "arrow/engine/substrait/relation_internal.h" - -namespace arrow { - -namespace engine { - -class SubstraitConversionRegistryImpl : public SubstraitConversionRegistry { - public: - virtual ~SubstraitConversionRegistryImpl() {} - - Result GetConverter(const std::string& factory_name) override { - auto it = name_to_converter_.find(factory_name); - if (it == name_to_converter_.end()) { - return Status::KeyError("SubstraitConverter named ", factory_name, - " not present in registry."); - } - return it->second; - } - - Status RegisterConverter(std::string factory_name, - SubstraitConverter converter) override { - auto it_success = - name_to_converter_.emplace(std::move(factory_name), std::move(converter)); - - if (!it_success.second) { - const auto& factory_name = it_success.first->first; - return Status::KeyError("SubstraitConverter named ", factory_name, - " already registered."); - } - return Status::OK(); - } - - private: - std::unordered_map name_to_converter_; -}; - -struct DefaultSubstraitConversionRegistry : SubstraitConversionRegistryImpl { - DefaultSubstraitConversionRegistry() { - DCHECK_OK(RegisterConverter("scan", ScanRelationConverter)); - DCHECK_OK(RegisterConverter("filter", FilterRelationConverter)); - } -}; - -SubstraitConversionRegistry* default_substrait_conversion_registry() { - static DefaultSubstraitConversionRegistry impl_; - return &impl_; -} - -} // namespace engine -} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/registry.h b/cpp/src/arrow/engine/substrait/registry.h deleted file mode 100644 index 82c3bc2a7f8..00000000000 --- a/cpp/src/arrow/engine/substrait/registry.h +++ /dev/null @@ -1,86 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// NOTE: API is EXPERIMENTAL and will change without going through a -// deprecation cycle - -#pragma once - -#include -#include -#include - -#include "arrow/result.h" -#include "arrow/status.h" -#include "arrow/util/visibility.h" - -#include "arrow/compute/exec/exec_plan.h" -#include "arrow/engine/substrait/extension_set.h" -#include "arrow/engine/substrait/extension_types.h" -#include "arrow/engine/substrait/options.h" -#include "arrow/engine/substrait/relation_internal.h" -#include "arrow/engine/substrait/serde.h" -#include "arrow/engine/substrait/visibility.h" -#include "arrow/type_fwd.h" - -#include "substrait/algebra.pb.h" // IWYU pragma: export - -namespace arrow { - -namespace engine { - -/// \brief Acero-Substrait integration contains converters which enables -/// converting Acero ExecPlan related entities to the corresponding Substrait -/// entities. -/// -/// Note that the current registry definition only holds converters to convert -/// an Acero plan to Substrait plan. -class ARROW_ENGINE_EXPORT SubstraitConversionRegistry { - public: - virtual ~SubstraitConversionRegistry() = default; - - /// \brief Alias for Acero-to-Substrait converter - using SubstraitConverter = std::function>( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, - const ConversionOptions&)>; - - /// \brief Retrieve a SubstraitConverter from the registry by factory name - /// - /// \param[in] factory_name name of the converter (aligned with Acero ExecNode kind - /// name) \return the matching SubstraitConverter - virtual Result GetConverter(const std::string& factory_name) = 0; - - /// \brief Register a converter by factory - /// - /// \param[in] factory_name name of the converter - /// \param[in] converter the std::function encapsulating the converter logic - /// \return Status of the registration - virtual Status RegisterConverter(std::string factory_name, - SubstraitConverter converter) = 0; -}; - -/// \brief Retrive the default Acero-to-Substrait conversion registry -/// The default registry contains the converters corresponding to mapping -/// the core ExecNodes in Acero. -/// -/// The default registry can be represented as a parent registry if a non-Acero -/// converters are required to be used with it. It must be separately implemented -/// by using the default input as the parent. -ARROW_ENGINE_EXPORT SubstraitConversionRegistry* default_substrait_conversion_registry(); - -} // namespace engine -} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 3e6f3d31720..c2d767cbe53 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -25,7 +25,6 @@ #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/expression_internal.h" -#include "arrow/engine/substrait/registry.h" #include "arrow/engine/substrait/type_internal.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" @@ -40,6 +39,10 @@ using internal::make_unique; namespace engine { +using SubstraitConverter = std::function>( + const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, + const ConversionOptions&)>; + template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { @@ -480,12 +483,26 @@ Status SerializeAndCombineRelations(const compute::Declaration& declaration, } const auto& factory_name = declaration.factory_name; ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); - SubstraitConversionRegistry* registry = default_substrait_conversion_registry(); - if (factory_name != "sink") { - ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetConverter(factory_name)); - ARROW_ASSIGN_OR_RAISE(auto factory_rel, - factory(schema, declaration, ext_set, conversion_options)); + // Note that the sink declaration factory doesn't exist for serialization as + // Substrait doesn't deal with a sink node definition + std::unique_ptr factory_rel; + if (factory_name == "scan") { + ARROW_ASSIGN_OR_RAISE(factory_rel, ScanRelationConverter(schema, declaration, ext_set, + conversion_options)); + } else if (factory_name == "filter") { + ARROW_ASSIGN_OR_RAISE( + factory_rel, + FilterRelationConverter(schema, declaration, ext_set, conversion_options)); + } else { + return Status::NotImplemented("Factory ", factory_name, + " not implemented for roundtripping."); + } + + if (factory_rel != nullptr) { RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); + } else { + return Status::Invalid("Conversion on factory ", factory_name, + " returned an invalid relation"); } return Status::OK(); } From ef407b09a6f651ec13f95e564de900c67a2c9581 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 29 Aug 2022 10:27:43 +0530 Subject: [PATCH 21/30] fix(reviews): uri fix, remove SetRelation, simplify code --- .../engine/substrait/relation_internal.cc | 129 ++++++++---------- .../engine/substrait/relation_internal.h | 10 -- cpp/src/arrow/engine/substrait/serde.cc | 5 +- cpp/src/arrow/engine/substrait/serde.h | 19 ++- cpp/src/arrow/filesystem/localfs_test.cc | 14 +- cpp/src/arrow/util/uri.cc | 10 ++ cpp/src/arrow/util/uri.h | 5 + 7 files changed, 90 insertions(+), 102 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c2d767cbe53..2cd2bf079a9 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -31,18 +31,16 @@ #include "arrow/filesystem/util_internal.h" #include "arrow/util/checked_cast.h" #include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" namespace arrow { +using ::arrow::internal::UriFromAbsolutePath; using internal::checked_cast; using internal::make_unique; namespace engine { -using SubstraitConverter = std::function>( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, - const ConversionOptions&)>; - template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { @@ -439,18 +437,19 @@ Result> ToProto( return std::move(rel); } -Status SetRelation(const std::unique_ptr& plan, - const std::unique_ptr& partial_plan, - const std::string& factory_name) { - if (factory_name == "scan" && partial_plan->has_read()) { - plan->set_allocated_read(partial_plan->release_read()); - } else if (factory_name == "filter" && partial_plan->has_filter()) { - plan->set_allocated_filter(partial_plan->release_filter()); - } else { - return Status::NotImplemented("Substrait converter ", factory_name, - " not supported."); +namespace { + +Result> GetRelationFromDeclaration( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto declr_input = declaration.inputs[0]; + // Note that the input is expected in declaration. + // ExecNode inputs are not accepted + if (util::get_if(&declr_input)) { + return Status::NotImplemented("Only support Plans written in Declaration format."); } - return Status::OK(); + return ToProto(util::get(declr_input), ext_set, + conversion_options); } Result> ExtractSchemaToBind(const compute::Declaration& declr) { @@ -471,59 +470,9 @@ Result> ExtractSchemaToBind(const compute::Declaration& return bind_schema; } -Status SerializeAndCombineRelations(const compute::Declaration& declaration, - ExtensionSet* ext_set, - std::unique_ptr& rel, - const ConversionOptions& conversion_options) { - std::vector inputs = declaration.inputs; - for (auto& input : inputs) { - auto input_decl = util::get(input); - RETURN_NOT_OK( - SerializeAndCombineRelations(input_decl, ext_set, rel, conversion_options)); - } - const auto& factory_name = declaration.factory_name; - ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); - // Note that the sink declaration factory doesn't exist for serialization as - // Substrait doesn't deal with a sink node definition - std::unique_ptr factory_rel; - if (factory_name == "scan") { - ARROW_ASSIGN_OR_RAISE(factory_rel, ScanRelationConverter(schema, declaration, ext_set, - conversion_options)); - } else if (factory_name == "filter") { - ARROW_ASSIGN_OR_RAISE( - factory_rel, - FilterRelationConverter(schema, declaration, ext_set, conversion_options)); - } else { - return Status::NotImplemented("Factory ", factory_name, - " not implemented for roundtripping."); - } - - if (factory_rel != nullptr) { - RETURN_NOT_OK(SetRelation(rel, factory_rel, factory_name)); - } else { - return Status::Invalid("Conversion on factory ", factory_name, - " returned an invalid relation"); - } - return Status::OK(); -} - -Result> GetRelationFromDeclaration( - const compute::Declaration& declaration, ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { - auto declr_input = declaration.inputs[0]; - // Note that the input is expected in declaration. - // ExecNode inputs are not accepted - if (util::get_if(&declr_input)) { - return Status::NotImplemented("Only support Plans written in Declaration format."); - } - return ToProto(util::get(declr_input), ext_set, - conversion_options); -} - -Result> ScanRelationConverter( +Result> ScanRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto rel = make_unique(); auto read_rel = make_unique(); const auto& scan_node_options = checked_cast(*declaration.options); @@ -541,7 +490,7 @@ Result> ScanRelationConverter( auto read_rel_lfs = make_unique(); for (const auto& file : dataset->files()) { auto read_rel_lfs_ffs = make_unique(); - read_rel_lfs_ffs->set_uri_path("file://" + file); + read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); // set file format // arrow and feather are temporarily handled via the Parquet format until // upgraded to the latest Substrait version. @@ -564,14 +513,12 @@ Result> ScanRelationConverter( read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); } read_rel->set_allocated_local_files(read_rel_lfs.release()); - rel->set_allocated_read(read_rel.release()); - return std::move(rel); + return std::move(read_rel); } -Result> FilterRelationConverter( +Result> FilterRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto rel = make_unique(); auto filter_rel = make_unique(); const auto& filter_node_options = checked_cast(*(declaration.options)); @@ -593,8 +540,44 @@ Result> FilterRelationConverter( ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set, conversion_options)); filter_rel->set_allocated_condition(subs_expr.release()); - rel->set_allocated_filter(filter_rel.release()); - return std::move(rel); + return std::move(filter_rel); +} + +} // namespace + +Status SerializeAndCombineRelations(const compute::Declaration& declaration, + ExtensionSet* ext_set, + std::unique_ptr& rel, + const ConversionOptions& conversion_options) { + std::vector inputs = declaration.inputs; + for (auto& input : inputs) { + auto input_decl = util::get(input); + RETURN_NOT_OK( + SerializeAndCombineRelations(input_decl, ext_set, rel, conversion_options)); + } + const auto& factory_name = declaration.factory_name; + ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); + // Note that the sink declaration factory doesn't exist for serialization as + // Substrait doesn't deal with a sink node definition + + // ignore the sink relation in a plan, since sink is implicitly added + if (factory_name != "sink") { + if (factory_name == "scan") { + ARROW_ASSIGN_OR_RAISE( + auto read_rel, + ScanRelationConverter(schema, declaration, ext_set, conversion_options)); + rel->set_allocated_read(read_rel.release()); + } else if (factory_name == "filter") { + ARROW_ASSIGN_OR_RAISE( + auto filter_rel, + FilterRelationConverter(schema, declaration, ext_set, conversion_options)); + rel->set_allocated_filter(filter_rel.release()); + } else { + return Status::NotImplemented("Factory ", factory_name, + " not implemented for roundtripping."); + } + } + return Status::OK(); } } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 1cb2f3673ef..79c87315f70 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -71,15 +71,5 @@ ARROW_ENGINE_EXPORT Status SerializeAndCombineRelations(const compute::Declarati ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); -/// \brief Acero to Substrait converter for Acero scan relation. -ARROW_ENGINE_EXPORT Result> ScanRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, - const ConversionOptions&); - -/// \brief Acero to Substrait converter for Acero filter relation. -ARROW_ENGINE_EXPORT Result> FilterRelationConverter( - const std::shared_ptr&, const compute::Declaration&, ExtensionSet*, - const ConversionOptions&); - } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index a311a4763ba..c6297675492 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -53,9 +53,10 @@ Result ParseFromBuffer(const Buffer& buf) { } Result> SerializePlan( - const compute::Declaration& declr, ExtensionSet* ext_set, + const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto subs_plan, PlanToProto(declr, ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE(auto subs_plan, + PlanToProto(declaration, ext_set, conversion_options)); std::string serialized = subs_plan->SerializeAsString(); return Buffer::FromString(std::move(serialized)); } diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index a8cb6e20d3a..2a14ca67570 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -36,9 +36,17 @@ namespace arrow { namespace engine { +/// \brief Serialize an Acero Plan to a binary protobuf Substrait message +/// +/// \param[in] declaration the Acero declaration to serialize. +/// This declaration is the sink relation of the Acero plan. +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// \param[in] conversion_options options to control how the conversion is done +/// +/// \return a buffer containing the protobuf serialization of the Acero relation ARROW_ENGINE_EXPORT Result> SerializePlan( - const compute::Declaration& declr, ExtensionSet* ext_set, + const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); /// Factory function type for generating the node that consumes the batches produced by @@ -207,14 +215,13 @@ Result> SerializeExpression( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); -/// \brief Serializes an Arrow compute Declaration to a Substrait Relation message +/// \brief Serialize an Acero Declaration to a binary protobuf Substrait message /// -/// \param[in] declaration the Arrow compute declaration to serialize +/// \param[in] declaration the Acero declaration to serialize /// \param[in,out] ext_set the extension mapping to use; may be updated to add /// \param[in] conversion_options options to control how the conversion is done -/// mappings for the components in the used declaration -/// \return a buffer containing the protobuf serialization of the corresponding Substrait -/// relation message +/// +/// \return a buffer containing the protobuf serialization of the Acero relation ARROW_ENGINE_EXPORT Result> SerializeRelation( const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); diff --git a/cpp/src/arrow/filesystem/localfs_test.cc b/cpp/src/arrow/filesystem/localfs_test.cc index 0078a593938..fd36faf30fa 100644 --- a/cpp/src/arrow/filesystem/localfs_test.cc +++ b/cpp/src/arrow/filesystem/localfs_test.cc @@ -32,6 +32,7 @@ #include "arrow/filesystem/util_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/io_util.h" +#include "arrow/util/uri.h" namespace arrow { namespace fs { @@ -40,6 +41,7 @@ namespace internal { using ::arrow::internal::FileDescriptor; using ::arrow::internal::PlatformFilename; using ::arrow::internal::TemporaryDir; +using ::arrow::internal::UriFromAbsolutePath; class LocalFSTestMixin : public ::testing::Test { public: @@ -173,16 +175,6 @@ class TestLocalFS : public LocalFSTestMixin { fs_ = std::make_shared(local_path_, local_fs_); } - std::string UriFromAbsolutePath(const std::string& path) { -#ifdef _WIN32 - // Path is supposed to start with "X:/..." - return "file:///" + path; -#else - // Path is supposed to start with "/..." - return "file://" + path; -#endif - } - template void CheckFileSystemFromUriFunc(const std::string& uri, FileSystemFromUriFunc&& fs_from_uri) { @@ -307,7 +299,7 @@ TYPED_TEST(TestLocalFS, NormalizePathThroughSubtreeFS) { TYPED_TEST(TestLocalFS, FileSystemFromUriFile) { // Concrete test with actual file - const auto uri_string = this->UriFromAbsolutePath(this->local_path_); + const auto uri_string = UriFromAbsolutePath(this->local_path_); this->TestFileSystemFromUri(uri_string); this->TestFileSystemFromUriOrPath(uri_string); diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc index 7a8484ce51a..abfc9de8b49 100644 --- a/cpp/src/arrow/util/uri.cc +++ b/cpp/src/arrow/util/uri.cc @@ -304,5 +304,15 @@ Status Uri::Parse(const std::string& uri_string) { return Status::OK(); } +std::string UriFromAbsolutePath(const std::string& path) { +#ifdef _WIN32 + // Path is supposed to start with "X:/..." + return "file:///" + path; +#else + // Path is supposed to start with "/..." + return "file://" + path; +#endif +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h index eae1956eafc..7ea82d33c5d 100644 --- a/cpp/src/arrow/util/uri.h +++ b/cpp/src/arrow/util/uri.h @@ -104,5 +104,10 @@ std::string UriEncodeHost(const std::string& host); ARROW_EXPORT bool IsValidUriScheme(const arrow::util::string_view s); +/// Create a file uri from a given URI +/// file:/// +ARROW_EXPORT +std::string UriFromAbsolutePath(const std::string& path); + } // namespace internal } // namespace arrow From 6571de2129a337de21a82dba998364e98cfd4b41 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 29 Aug 2022 11:09:22 +0530 Subject: [PATCH 22/30] fix(cleanup): raddressing reviews --- .../arrow/engine/substrait/plan_internal.cc | 2 +- .../engine/substrait/relation_internal.cc | 8 +++--- .../engine/substrait/relation_internal.h | 27 +++++++++---------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 6caf0b69193..27c3ae0f5d8 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -142,7 +142,7 @@ Result> PlanToProto( auto subs_plan = internal::make_unique(); auto plan_rel = internal::make_unique(); auto rel = internal::make_unique(); - RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, rel, conversion_options)); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); plan_rel->set_allocated_rel(rel.release()); subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 2cd2bf079a9..aeaed44aeec 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -433,7 +433,7 @@ Result> ToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { auto rel = make_unique(); - RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, rel, conversion_options)); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); return std::move(rel); } @@ -547,7 +547,7 @@ Result> FilterRelationConverter( Status SerializeAndCombineRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, - std::unique_ptr& rel, + std::unique_ptr* rel, const ConversionOptions& conversion_options) { std::vector inputs = declaration.inputs; for (auto& input : inputs) { @@ -566,12 +566,12 @@ Status SerializeAndCombineRelations(const compute::Declaration& declaration, ARROW_ASSIGN_OR_RAISE( auto read_rel, ScanRelationConverter(schema, declaration, ext_set, conversion_options)); - rel->set_allocated_read(read_rel.release()); + (*rel)->set_allocated_read(read_rel.release()); } else if (factory_name == "filter") { ARROW_ASSIGN_OR_RAISE( auto filter_rel, FilterRelationConverter(schema, declaration, ext_set, conversion_options)); - rel->set_allocated_filter(filter_rel.release()); + (*rel)->set_allocated_filter(filter_rel.release()); } else { return Status::NotImplemented("Factory ", factory_name, " not implemented for roundtripping."); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 79c87315f70..225b82ef76d 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -45,29 +45,26 @@ ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); -/// \brief Serializes a Declaration, produce a Substrait Rel and update the global -/// Substrait plan. A Substrait Rel is passed as a the plan and it is updated with +/// \brief Convert a Declaration (and its inputs) to a Substrait Rel +/// +/// A Substrait Rel is passed as a the plan and it is updated with /// corresponding Declaration passed for serialization. /// -/// Note that this is a rather a helper method useful to fuse a partially serialized -/// plan with another plan. The reason for having a partially serialized plan is to -/// avoid unnecessary complication and enable partial plan serialization without -/// affecting a global plan. Since kept as unique_ptr resources are relased efficiently -/// upon releasing for the global plan. +/// Note that this used to fuse a partially serialized plan with another plan. +/// Partially serialized plan is recursively being used to generate global plan. +/// Since kept as unique_ptr resources are relased efficiently upon releasing for +/// the global plan. ARROW_ENGINE_EXPORT Status SerializeAndCombineRelations(const compute::Declaration&, ExtensionSet*, - std::unique_ptr&, + std::unique_ptr*, const ConversionOptions&); -/// \brief Serialize a Declaration and produces a Substrait Rel. +/// \brief Convert an Acero Declaration to a Substrait Rel /// /// Note that in order to provide a generic interface for ToProto for -/// declaration it is not specialized for each relation within the Substrait Rel. -/// Rather a serialized relation is set as a member for the Substrait Rel -/// (partial Relation) which is later on extracted to update a Substrait Rel -/// which would be included in the fully serialized Acero Exec Plan. -/// The ExecNode or ExecPlan is not used in this context as Declaration is preferred -/// in the Substrait space rather than internal components of Acero execution engine. +/// declaration. The ExecNode or ExecPlan is not used in this context as Declaration +/// is preferred in the Substrait space rather than internal components of +/// Acero execution engine. ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); From 3841bfb352be1bc7fb792a6446a355e700c927ba Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 29 Aug 2022 13:34:05 +0530 Subject: [PATCH 23/30] fix(reviews): updated input handling --- .../arrow/engine/substrait/plan_internal.cc | 8 +- .../arrow/engine/substrait/plan_internal.h | 8 +- .../engine/substrait/relation_internal.cc | 86 ++++++++----------- .../engine/substrait/relation_internal.h | 16 +--- 4 files changed, 46 insertions(+), 72 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 27c3ae0f5d8..1efd4e1a0a9 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -18,7 +18,6 @@ #include "arrow/engine/substrait/plan_internal.h" #include "arrow/dataset/plan.h" -#include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/result.h" #include "arrow/util/hashing.h" @@ -141,9 +140,10 @@ Result> PlanToProto( const ConversionOptions& conversion_options) { auto subs_plan = internal::make_unique(); auto plan_rel = internal::make_unique(); - auto rel = internal::make_unique(); - RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); - plan_rel->set_allocated_rel(rel.release()); + auto rel_root = internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(auto rel, ToProto(declr, ext_set, conversion_options)); + rel_root->set_allocated_input(rel.release()); + plan_rel->set_allocated_root(rel_root.release()); subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); return std::move(subs_plan); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index e8a07ad666f..9ebb629db1d 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -53,14 +53,14 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); -/// \brief Serializes Declaration and produces a substrait::Plan. +/// \brief Serialize a declaration and into a substrait::Plan. /// /// Note that, this is a part of roundtripping test API and not /// designed to use in production -/// \param[in] declr the sequence of declarations +/// \param[in] declr the sequence of declarations to be serialized /// \param[in, out] ext_set the extension set to be updated -/// \param[in] conversion_options the conversion options useful for the serialization -/// \return serialized Acero plan +/// \param[in] conversion_options options to control serialization behavior +/// \return the serialized plan ARROW_ENGINE_EXPORT Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index aeaed44aeec..da5b2349ccd 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -429,29 +429,8 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& rel.DebugString()); } -Result> ToProto( - const compute::Declaration& declr, ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { - auto rel = make_unique(); - RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); - return std::move(rel); -} - namespace { -Result> GetRelationFromDeclaration( - const compute::Declaration& declaration, ExtensionSet* ext_set, - const ConversionOptions& conversion_options) { - auto declr_input = declaration.inputs[0]; - // Note that the input is expected in declaration. - // ExecNode inputs are not accepted - if (util::get_if(&declr_input)) { - return Status::NotImplemented("Only support Plans written in Declaration format."); - } - return ToProto(util::get(declr_input), ext_set, - conversion_options); -} - Result> ExtractSchemaToBind(const compute::Declaration& declr) { std::shared_ptr bind_schema; if (declr.factory_name == "scan") { @@ -479,7 +458,8 @@ Result> ScanRelationConverter( auto dataset = dynamic_cast(scan_node_options.dataset.get()); if (dataset == nullptr) { - return Status::Invalid("Can only convert file system datasets to a Substrait plan."); + return Status::Invalid( + "Can only convert scan node with FileSystemDataset to a Substrait plan."); } // set schema ARROW_ASSIGN_OR_RAISE(auto named_struct, @@ -492,8 +472,6 @@ Result> ScanRelationConverter( auto read_rel_lfs_ffs = make_unique(); read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); // set file format - // arrow and feather are temporarily handled via the Parquet format until - // upgraded to the latest Substrait version. auto format_type_name = dataset->format()->type_name(); if (format_type_name == "parquet") { auto parquet_fmt = @@ -533,9 +511,12 @@ Result> FilterRelationConverter( return Status::Invalid("Filter node doesn't have an input."); } - auto input_rel = GetRelationFromDeclaration(declaration, ext_set, conversion_options); - - filter_rel->set_allocated_input(input_rel->release()); + // handling input + auto declr_input = declaration.inputs[0]; + ARROW_ASSIGN_OR_RAISE( + auto input_rel, + ToProto(util::get(declr_input), ext_set, conversion_options)); + filter_rel->set_allocated_input(input_rel.release()); ARROW_ASSIGN_OR_RAISE(auto subs_expr, ToProto(bound_expression, ext_set, conversion_options)); @@ -549,36 +530,43 @@ Status SerializeAndCombineRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, std::unique_ptr* rel, const ConversionOptions& conversion_options) { - std::vector inputs = declaration.inputs; - for (auto& input : inputs) { - auto input_decl = util::get(input); - RETURN_NOT_OK( - SerializeAndCombineRelations(input_decl, ext_set, rel, conversion_options)); - } const auto& factory_name = declaration.factory_name; ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); // Note that the sink declaration factory doesn't exist for serialization as // Substrait doesn't deal with a sink node definition - // ignore the sink relation in a plan, since sink is implicitly added - if (factory_name != "sink") { - if (factory_name == "scan") { - ARROW_ASSIGN_OR_RAISE( - auto read_rel, - ScanRelationConverter(schema, declaration, ext_set, conversion_options)); - (*rel)->set_allocated_read(read_rel.release()); - } else if (factory_name == "filter") { - ARROW_ASSIGN_OR_RAISE( - auto filter_rel, - FilterRelationConverter(schema, declaration, ext_set, conversion_options)); - (*rel)->set_allocated_filter(filter_rel.release()); - } else { - return Status::NotImplemented("Factory ", factory_name, - " not implemented for roundtripping."); - } + if (factory_name == "scan") { + ARROW_ASSIGN_OR_RAISE( + auto read_rel, + ScanRelationConverter(schema, declaration, ext_set, conversion_options)); + (*rel)->set_allocated_read(read_rel.release()); + } else if (factory_name == "filter") { + ARROW_ASSIGN_OR_RAISE( + auto filter_rel, + FilterRelationConverter(schema, declaration, ext_set, conversion_options)); + (*rel)->set_allocated_filter(filter_rel.release()); + } else if (factory_name == "sink") { + // Generally when a plan is deserialized the declaration will be a sink declaration. + // Since there is no Sink relation in substrait, this function would be recursively + // called on the input of the Sink declaration. + auto sink_input_decl = util::get(declaration.inputs[0]); + RETURN_NOT_OK( + SerializeAndCombineRelations(sink_input_decl, ext_set, rel, conversion_options)); + } else { + return Status::NotImplemented("Factory ", factory_name, + " not implemented for roundtripping."); } + return Status::OK(); } +Result> ToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto rel = make_unique(); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); + return std::move(rel); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 225b82ef76d..8c1367b6d85 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -40,25 +40,11 @@ struct DeclarationInfo { int num_columns; }; -/// \brief A function to extract Acero Declaration from a Substrait Rel object +/// \brief Convert a Substrait Rel object to an Acero declaration ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); -/// \brief Convert a Declaration (and its inputs) to a Substrait Rel -/// -/// A Substrait Rel is passed as a the plan and it is updated with -/// corresponding Declaration passed for serialization. -/// -/// Note that this used to fuse a partially serialized plan with another plan. -/// Partially serialized plan is recursively being used to generate global plan. -/// Since kept as unique_ptr resources are relased efficiently upon releasing for -/// the global plan. -ARROW_ENGINE_EXPORT Status SerializeAndCombineRelations(const compute::Declaration&, - ExtensionSet*, - std::unique_ptr*, - const ConversionOptions&); - /// \brief Convert an Acero Declaration to a Substrait Rel /// /// Note that in order to provide a generic interface for ToProto for From b9d6f073a6cad65b7cd316115e14c24459fd97c2 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Mon, 29 Aug 2022 13:49:54 +0530 Subject: [PATCH 24/30] fix(native): updated the file_path method to check CI failure --- cpp/src/arrow/engine/substrait/serde_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index a615d5a9aff..79985e8bb8d 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -2006,7 +2006,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait_tempdir")); ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); - std::string file_path_str = file_path.ToNative(); + std::string file_path_str = file_path.ToString(); // Note: there is an additional forward slash introduced by the tempdir // it must be replaced to properly load into reading files From 0479dacff207757dcbf954ff00127a99ad631eb4 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 31 Aug 2022 09:44:12 +0530 Subject: [PATCH 25/30] fix(ipc): adding ipc write replacing parquet --- .../engine/substrait/relation_internal.cc | 2 +- cpp/src/arrow/engine/substrait/serde_test.cc | 40 ++++++++++++------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index da5b2349ccd..6b3789da956 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -477,7 +477,7 @@ Result> ScanRelationConverter( auto parquet_fmt = make_unique(); read_rel_lfs_ffs->set_allocated_parquet(parquet_fmt.release()); - } else if (format_type_name == "arrow") { + } else if (format_type_name == "ipc") { auto arrow_fmt = make_unique(); read_rel_lfs_ffs->set_allocated_arrow(arrow_fmt.release()); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 79985e8bb8d..31a92238453 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -58,13 +58,24 @@ using internal::checked_cast; using internal::hash_combine; namespace engine { -Status WriteParquetData(const std::string& path, - const std::shared_ptr file_system, - const std::shared_ptr
input) { - EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path)); - PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(), - buffer_writer, /*chunk_size*/ 1)); - return buffer_writer->Close(); +Status WriteIpcData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr
input) { + EXPECT_OK_AND_ASSIGN(auto mmap, file_system->OpenOutputStream(path)); + ARROW_ASSIGN_OR_RAISE( + auto file_writer, + MakeFileWriter(mmap, input->schema(), ipc::IpcWriteOptions::Defaults())); + TableBatchReader reader(input); + std::shared_ptr batch; + while (true) { + RETURN_NOT_OK(reader.ReadNext(&batch)); + if (batch == nullptr) { + break; + } + RETURN_NOT_OK(file_writer->WriteRecordBatch(*batch)); + } + RETURN_NOT_OK(file_writer->Close()); + return Status::OK(); } Result> GetTableFromPlan( @@ -1880,15 +1891,16 @@ TEST(Substrait, BasicPlanRoundTripping) { [1, 3, 12] ])"}); - auto format = std::make_shared(); + auto format = std::make_shared(); auto filesystem = std::make_shared(); - const std::string file_name = "serde_test.parquet"; + const std::string file_name = "serde_test.arrow"; ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait_tempdir")); + std::cout << "file_path_str " << tempdir->path().ToString() << std::endl; ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - + std::cout << "file_path_str " << file_path_str << std::endl; // Note: there is an additional forward slash introduced by the tempdir // it must be replaced to properly load into reading files // TODO: (Review: Jira needs to be reported to handle this properly) @@ -1896,7 +1908,7 @@ TEST(Substrait, BasicPlanRoundTripping) { size_t pos = file_path_str.find(toReplace); file_path_str.replace(pos, toReplace.length(), "/T/"); - ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); std::vector files; const std::vector f_paths = {file_path_str}; @@ -1999,9 +2011,9 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { [1, 3, 3] ])"}); - auto format = std::make_shared(); + auto format = std::make_shared(); auto filesystem = std::make_shared(); - const std::string file_name = "serde_test.parquet"; + const std::string file_name = "serde_test.arrow"; ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait_tempdir")); @@ -2015,7 +2027,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { size_t pos = file_path_str.find(toReplace); file_path_str.replace(pos, toReplace.length(), "/T/"); - ARROW_EXPECT_OK(WriteParquetData(file_path_str, filesystem, table)); + ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); std::vector files; const std::vector f_paths = {file_path_str}; From c1de2b8824702c20d4d8e744ff9f5928b129be18 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 6 Sep 2022 11:26:18 +0530 Subject: [PATCH 26/30] fix(file_path_issue): temp commit --- cpp/src/arrow/engine/substrait/serde_test.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 31a92238453..483730872da 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1904,9 +1904,9 @@ TEST(Substrait, BasicPlanRoundTripping) { // Note: there is an additional forward slash introduced by the tempdir // it must be replaced to properly load into reading files // TODO: (Review: Jira needs to be reported to handle this properly) - std::string toReplace("/T//"); - size_t pos = file_path_str.find(toReplace); - file_path_str.replace(pos, toReplace.length(), "/T/"); + // std::string toReplace("/T//"); + // size_t pos = file_path_str.find(toReplace); + // file_path_str.replace(pos, toReplace.length(), "/T/"); ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); @@ -2023,9 +2023,9 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { // Note: there is an additional forward slash introduced by the tempdir // it must be replaced to properly load into reading files // TODO: (Review: Jira needs to be reported to handle this properly) - std::string toReplace("/T//"); - size_t pos = file_path_str.find(toReplace); - file_path_str.replace(pos, toReplace.length(), "/T/"); + // std::string toReplace("/T//"); + // size_t pos = file_path_str.find(toReplace); + // file_path_str.replace(pos, toReplace.length(), "/T/"); ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); From 3156fd29027e0678ab40f3baee526bcbcadd7db4 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Tue, 6 Sep 2022 18:10:41 +0530 Subject: [PATCH 27/30] fix(temp): testing a fix for additional slash in file handling --- cpp/src/arrow/engine/substrait/serde_test.cc | 8 ++++++-- cpp/src/arrow/util/io_util.cc | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 483730872da..ca25f743a2f 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1906,7 +1906,9 @@ TEST(Substrait, BasicPlanRoundTripping) { // TODO: (Review: Jira needs to be reported to handle this properly) // std::string toReplace("/T//"); // size_t pos = file_path_str.find(toReplace); - // file_path_str.replace(pos, toReplace.length(), "/T/"); + // if (pos >= 0) { + // file_path_str.replace(pos, toReplace.length(), "/T/"); + // } ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); @@ -2025,7 +2027,9 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { // TODO: (Review: Jira needs to be reported to handle this properly) // std::string toReplace("/T//"); // size_t pos = file_path_str.find(toReplace); - // file_path_str.replace(pos, toReplace.length(), "/T/"); + // if (pos >= 0) { + // file_path_str.replace(pos, toReplace.length(), "/T/"); + // } ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index 11ae80d03e2..bf0ac26dc3d 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -1867,7 +1867,7 @@ Result> TemporaryDir::Make(const std::string& pref [&](const NativePathString& base_dir) -> Result> { Status st; for (int attempt = 0; attempt < 3; ++attempt) { - PlatformFilename fn(base_dir + kNativeSep + base_name + kNativeSep); + PlatformFilename fn(base_dir + base_name + kNativeSep); auto result = CreateDir(fn); if (!result.ok()) { // Probably a permissions error or a non-existing base_dir From 491a985e6c5005eb082f8ec94b8120c4f4eeb760 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 7 Sep 2022 09:45:39 +0530 Subject: [PATCH 28/30] fix(path-issue): fixed the path issue and updated the test cases --- .../engine/substrait/relation_internal.cc | 52 +++-- cpp/src/arrow/engine/substrait/serde_test.cc | 209 ++++++++---------- cpp/src/arrow/filesystem/util_internal.cc | 1 + cpp/src/arrow/util/io_util.cc | 11 +- 4 files changed, 135 insertions(+), 138 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 6b3789da956..941d9030aa4 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -170,36 +170,40 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } path = path.substr(7); - if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath) { - ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); - if (file.type() == fs::FileType::File) { - files.push_back(std::move(file)); - } else if (file.type() == fs::FileType::Directory) { + switch (item.path_type_case()) { + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath: { + ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); + if (file.type() == fs::FileType::File) { + files.push_back(std::move(file)); + } else if (file.type() == fs::FileType::Directory) { + fs::FileSelector selector; + selector.base_dir = path; + selector.recursive = true; + ARROW_ASSIGN_OR_RAISE(auto discovered_files, + filesystem->GetFileInfo(selector)); + std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); + } + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile: { + files.emplace_back(path, fs::FileType::File); + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder: { fs::FileSelector selector; selector.base_dir = path; selector.recursive = true; ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); - std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); + std::move(discovered_files.begin(), discovered_files.end(), + std::back_inserter(files)); + break; } - } - if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile) { - files.emplace_back(path, fs::FileType::File); - } else if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder) { - fs::FileSelector selector; - selector.base_dir = path; - selector.recursive = true; - ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); - std::move(discovered_files.begin(), discovered_files.end(), - std::back_inserter(files)); - } else { - ARROW_ASSIGN_OR_RAISE(auto discovered_files, - fs::internal::GlobFiles(filesystem, path)); - std::move(discovered_files.begin(), discovered_files.end(), - std::back_inserter(files)); + default: + ARROW_ASSIGN_OR_RAISE(auto discovered_files, + fs::internal::GlobFiles(filesystem, path)); + std::move(discovered_files.begin(), discovered_files.end(), + std::back_inserter(files)); } } diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index ca25f743a2f..9b6c3f715f7 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1870,6 +1870,8 @@ TEST(Substrait, BasicPlanRoundTripping) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif compute::ExecContext exec_context; + arrow::dataset::internal::Initialize(); + auto dummy_schema = schema( {field("key", int32()), field("shared", int32()), field("distinct", int32())}); @@ -1896,19 +1898,10 @@ TEST(Substrait, BasicPlanRoundTripping) { const std::string file_name = "serde_test.arrow"; ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_tempdir")); + arrow::internal::TemporaryDir::Make("substrait-tempdir-")); std::cout << "file_path_str " << tempdir->path().ToString() << std::endl; ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - std::cout << "file_path_str " << file_path_str << std::endl; - // Note: there is an additional forward slash introduced by the tempdir - // it must be replaced to properly load into reading files - // TODO: (Review: Jira needs to be reported to handle this properly) - // std::string toReplace("/T//"); - // size_t pos = file_path_str.find(toReplace); - // if (pos >= 0) { - // file_path_str.replace(pos, toReplace.length(), "/T/"); - // } ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); @@ -1940,49 +1933,48 @@ TEST(Substrait, BasicPlanRoundTripping) { compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); - for (auto sp_ext_id_reg : {MakeExtensionIdRegistry()}) { - ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); - ExtensionSet ext_set(ext_id_reg); + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); - ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); - ASSERT_OK_AND_ASSIGN( - auto sink_decls, - DeserializePlans( - *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); - // filter declaration - auto roundtripped_filter = sink_decls[0].inputs[0].get(); - const auto& filter_opts = - checked_cast(*(roundtripped_filter->options)); - auto roundtripped_expr = filter_opts.filter_expression; - - if (auto* call = roundtripped_expr.call()) { - EXPECT_EQ(call->function_name, "equal"); - auto args = call->arguments; - auto left_index = args[0].field_ref()->field_path()->indices()[0]; - EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); - auto right_index = args[1].field_ref()->field_path()->indices()[0]; - EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); - } - // scan declaration - auto roundtripped_scan = roundtripped_filter->inputs[0].get(); - const auto& dataset_opts = - checked_cast(*(roundtripped_scan->options)); - const auto& roundripped_ds = dataset_opts.dataset; - EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); - ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); - ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); - - auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); - auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); - EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); - int64_t idx = 0; - for (auto fragment : expected_frg_vec) { - const auto* l_frag = checked_cast(fragment.get()); - const auto* r_frag = - checked_cast(roundtrip_frg_vec[idx++].get()); - EXPECT_TRUE(l_frag->Equals(*r_frag)); - } + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); } } @@ -1991,7 +1983,8 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif compute::ExecContext exec_context; - ExtensionSet ext_set; + arrow::dataset::internal::Initialize(); + auto dummy_schema = schema( {field("key", int32()), field("shared", int32()), field("distinct", int32())}); @@ -2018,19 +2011,10 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { const std::string file_name = "serde_test.arrow"; ASSERT_OK_AND_ASSIGN(auto tempdir, - arrow::internal::TemporaryDir::Make("substrait_tempdir")); + arrow::internal::TemporaryDir::Make("substrait-tempdir-")); ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); - // Note: there is an additional forward slash introduced by the tempdir - // it must be replaced to properly load into reading files - // TODO: (Review: Jira needs to be reported to handle this properly) - // std::string toReplace("/T//"); - // size_t pos = file_path_str.find(toReplace); - // if (pos >= 0) { - // file_path_str.replace(pos, toReplace.length(), "/T/"); - // } - ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); std::vector files; @@ -2064,60 +2048,59 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(declarations, sink_gen, exec_context, dummy_schema)); - for (auto sp_ext_id_reg : {MakeExtensionIdRegistry()}) { - ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); - ExtensionSet ext_set(ext_id_reg); + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); - ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); - ASSERT_OK_AND_ASSIGN( - auto sink_decls, - DeserializePlans( - *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); - // filter declaration - auto roundtripped_filter = sink_decls[0].inputs[0].get(); - const auto& filter_opts = - checked_cast(*(roundtripped_filter->options)); - auto roundtripped_expr = filter_opts.filter_expression; - - if (auto* call = roundtripped_expr.call()) { - EXPECT_EQ(call->function_name, "equal"); - auto args = call->arguments; - auto left_index = args[0].field_ref()->field_path()->indices()[0]; - EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); - auto right_index = args[1].field_ref()->field_path()->indices()[0]; - EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); - } - // scan declaration - auto roundtripped_scan = roundtripped_filter->inputs[0].get(); - const auto& dataset_opts = - checked_cast(*(roundtripped_scan->options)); - const auto& roundripped_ds = dataset_opts.dataset; - EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); - ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); - ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); - - auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); - auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); - EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); - int64_t idx = 0; - for (auto fragment : expected_frg_vec) { - const auto* l_frag = checked_cast(fragment.get()); - const auto* r_frag = - checked_cast(roundtrip_frg_vec[idx++].get()); - EXPECT_TRUE(l_frag->Equals(*r_frag)); - } - arrow::AsyncGenerator> rnd_trp_sink_gen; - auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; - auto rnd_trp_sink_declaration = - compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); - auto rnd_trp_declarations = - compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); - ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, - GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, - exec_context, dummy_schema)); - EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); } + arrow::AsyncGenerator> rnd_trp_sink_gen; + auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; + auto rnd_trp_sink_declaration = + compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); + auto rnd_trp_declarations = + compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, + GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, + exec_context, dummy_schema)); + EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } } // namespace engine diff --git a/cpp/src/arrow/filesystem/util_internal.cc b/cpp/src/arrow/filesystem/util_internal.cc index 0d2ad709026..e6f301bdbf1 100644 --- a/cpp/src/arrow/filesystem/util_internal.cc +++ b/cpp/src/arrow/filesystem/util_internal.cc @@ -78,6 +78,7 @@ Status InvalidDeleteDirContents(util::string_view path) { Result GlobFiles(const std::shared_ptr& filesystem, const std::string& glob) { + // TODO: ARROW-17640 // The candidate entries at the current depth level. // We start with the filesystem root. FileInfoVector results{FileInfo("", FileType::Directory)}; diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index bf0ac26dc3d..38bd4457cf2 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -1867,7 +1867,16 @@ Result> TemporaryDir::Make(const std::string& pref [&](const NativePathString& base_dir) -> Result> { Status st; for (int attempt = 0; attempt < 3; ++attempt) { - PlatformFilename fn(base_dir + base_name + kNativeSep); + // Note: certain temporary directories of MacOS contains a trailing slash + // Handling the base_dir with trailing slash + PlatformFilename fn; + if (base_dir.back() == kNativeSep) { + PlatformFilename fn_base_dir(base_dir); + PlatformFilename fn_base_name(base_name); + ARROW_ASSIGN_OR_RAISE(fn, fn_base_dir.Join(base_name + kNativeSep)); + } else { + fn = PlatformFilename(base_dir + kNativeSep + base_name + kNativeSep); + } auto result = CreateDir(fn); if (!result.ok()) { // Probably a permissions error or a non-existing base_dir From 33d77531981d87af592cbf7de8e515799ce83de1 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Wed, 7 Sep 2022 16:47:38 +0530 Subject: [PATCH 29/30] fix(path): windows issue fixing --- cpp/src/arrow/util/io_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index 38bd4457cf2..ba8613b6b06 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -1872,8 +1872,8 @@ Result> TemporaryDir::Make(const std::string& pref PlatformFilename fn; if (base_dir.back() == kNativeSep) { PlatformFilename fn_base_dir(base_dir); - PlatformFilename fn_base_name(base_name); - ARROW_ASSIGN_OR_RAISE(fn, fn_base_dir.Join(base_name + kNativeSep)); + PlatformFilename fn_base_name(base_name + kNativeSep); + fn = fn_base_dir.Join(fn_base_name); } else { fn = PlatformFilename(base_dir + kNativeSep + base_name + kNativeSep); } From 616d6e56e889a021d6dc89f4b3cb1abb2036aab1 Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 8 Sep 2022 13:31:09 +0530 Subject: [PATCH 30/30] fix(reviews): address reviews --- .../arrow/engine/substrait/plan_internal.h | 6 ++--- .../engine/substrait/relation_internal.cc | 22 ++++++++++--------- .../engine/substrait/relation_internal.h | 4 ++-- cpp/src/arrow/util/io_util.cc | 13 +++-------- cpp/src/arrow/util/uri.h | 3 +-- 5 files changed, 21 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 9ebb629db1d..e1ced549ce1 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -53,10 +53,10 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); -/// \brief Serialize a declaration and into a substrait::Plan. +/// \brief Serialize a declaration into a substrait::Plan. /// -/// Note that, this is a part of roundtripping test API and not -/// designed to use in production +/// Note that, this is a part of a roundtripping test API and not +/// designed for use in production /// \param[in] declr the sequence of declarations to be serialized /// \param[in, out] ext_set the extension set to be updated /// \param[in] conversion_options options to control serialization behavior diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 941d9030aa4..c5d212c8c2f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -199,11 +199,16 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::back_inserter(files)); break; } - default: + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPathGlob: { ARROW_ASSIGN_OR_RAISE(auto discovered_files, fs::internal::GlobFiles(filesystem, path)); std::move(discovered_files.begin(), discovered_files.end(), std::back_inserter(files)); + break; + } + default: { + return Status::Invalid("Unrecognized file type in LocalFiles"); + } } } @@ -478,17 +483,14 @@ Result> ScanRelationConverter( // set file format auto format_type_name = dataset->format()->type_name(); if (format_type_name == "parquet") { - auto parquet_fmt = - make_unique(); - read_rel_lfs_ffs->set_allocated_parquet(parquet_fmt.release()); + read_rel_lfs_ffs->set_allocated_parquet( + new substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); } else if (format_type_name == "ipc") { - auto arrow_fmt = - make_unique(); - read_rel_lfs_ffs->set_allocated_arrow(arrow_fmt.release()); + read_rel_lfs_ffs->set_allocated_arrow( + new substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); } else if (format_type_name == "orc") { - auto orc_fmt = - make_unique(); - read_rel_lfs_ffs->set_allocated_orc(orc_fmt.release()); + read_rel_lfs_ffs->set_allocated_orc( + new substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); } else { return Status::NotImplemented("Unsupported file type: ", format_type_name); } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 8c1367b6d85..778d1e5bc01 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -47,8 +47,8 @@ Result FromProto(const substrait::Rel&, const ExtensionSet&, /// \brief Convert an Acero Declaration to a Substrait Rel /// -/// Note that in order to provide a generic interface for ToProto for -/// declaration. The ExecNode or ExecPlan is not used in this context as Declaration +/// Note that, in order to provide a generic interface for ToProto, +/// the ExecNode or ExecPlan are not used in this context as Declaration /// is preferred in the Substrait space rather than internal components of /// Acero execution engine. ARROW_ENGINE_EXPORT Result> ToProto( diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index ba8613b6b06..a62040f3a70 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -1867,16 +1867,9 @@ Result> TemporaryDir::Make(const std::string& pref [&](const NativePathString& base_dir) -> Result> { Status st; for (int attempt = 0; attempt < 3; ++attempt) { - // Note: certain temporary directories of MacOS contains a trailing slash - // Handling the base_dir with trailing slash - PlatformFilename fn; - if (base_dir.back() == kNativeSep) { - PlatformFilename fn_base_dir(base_dir); - PlatformFilename fn_base_name(base_name + kNativeSep); - fn = fn_base_dir.Join(fn_base_name); - } else { - fn = PlatformFilename(base_dir + kNativeSep + base_name + kNativeSep); - } + PlatformFilename fn_base_dir(base_dir); + PlatformFilename fn_base_name(base_name + kNativeSep); + PlatformFilename fn = fn_base_dir.Join(fn_base_name); auto result = CreateDir(fn); if (!result.ok()) { // Probably a permissions error or a non-existing base_dir diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h index 7ea82d33c5d..50d9eccf82f 100644 --- a/cpp/src/arrow/util/uri.h +++ b/cpp/src/arrow/util/uri.h @@ -104,8 +104,7 @@ std::string UriEncodeHost(const std::string& host); ARROW_EXPORT bool IsValidUriScheme(const arrow::util::string_view s); -/// Create a file uri from a given URI -/// file:/// +/// Create a file uri from a given absolute path ARROW_EXPORT std::string UriFromAbsolutePath(const std::string& path);