diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 89ab7ca4dc3..1ea03b08b5a 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -27,10 +27,17 @@ #include "arrow/engine/substrait/type_internal.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/util_internal.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" namespace arrow { namespace engine { +namespace internal { +using ::arrow::internal::checked_cast; +using ::arrow::internal::make_unique; +} // namespace internal + template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { @@ -316,5 +323,133 @@ Result FromProto(const substrait::Rel& rel, rel.DebugString()); } +namespace { +// TODO: add other types +enum ArrowRelationType : uint8_t { + SCAN, + FILTER, + PROJECT, + JOIN, + AGGREGATE, +}; + +const std::map enum_map{ + {"scan", ArrowRelationType::SCAN}, {"filter", ArrowRelationType::FILTER}, + {"project", ArrowRelationType::PROJECT}, {"join", ArrowRelationType::JOIN}, + {"aggregate", ArrowRelationType::AGGREGATE}, +}; + +struct ExtractRelation { + explicit ExtractRelation(substrait::Rel* rel, ExtensionSet* ext_set) + : rel_(rel), ext_set_(ext_set) {} + + Status AddRelation(const compute::Declaration& declaration) { + const std::string& rel_name = declaration.factory_name; + switch (enum_map.find(rel_name)->second) { + case ArrowRelationType::SCAN: + return AddReadRelation(declaration); + case ArrowRelationType::FILTER: + return AddFilterRelation(declaration); + case ArrowRelationType::PROJECT: + return Status::NotImplemented("Project operator not supported."); + case ArrowRelationType::JOIN: + return Status::NotImplemented("Join operator not supported."); + case ArrowRelationType::AGGREGATE: + return Status::NotImplemented("Aggregate operator not supported."); + default: + return Status::Invalid("Unsupported exec node factory name :", rel_name); + } + } + + Status AddReadRelation(const compute::Declaration& declaration) { + auto read_rel = internal::make_unique(); + const auto& scan_node_options = + internal::checked_cast(*declaration.options); + + auto dataset = + dynamic_cast(scan_node_options.dataset.get()); + if (dataset == nullptr) { + return Status::Invalid( + "Can only convert file system datasets to a Substrait plan."); + } + // set schema + ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set_)); + read_rel->set_allocated_base_schema(named_struct.release()); + + // set local files + auto read_rel_lfs = internal::make_unique(); + for (const auto& file : dataset->files()) { + auto read_rel_lfs_ffs = + internal::make_unique(); + read_rel_lfs_ffs->set_uri_path("file://" + file); + + // set file format + // arrow and feather are temporarily handled via the Parquet format until + // upgraded to the latest Substrait version. + auto format_type_name = dataset->format()->type_name(); + if (format_type_name == "parquet" || format_type_name == "arrow" || + format_type_name == "feather") { + read_rel_lfs_ffs->set_format( + substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET); + } else { + return Status::Invalid("Unsupported file type : ", format_type_name); + } + read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); + } + *read_rel->mutable_local_files() = *read_rel_lfs.get(); + + rel_->set_allocated_read(read_rel.release()); + return Status::OK(); + } + + Status AddFilterRelation(const compute::Declaration& declaration) { + auto filter_rel = internal::make_unique(); + const auto& filter_node_options = + internal::checked_cast(*declaration.options); + + if (declaration.inputs.size() == 0) { + return Status::Invalid("Filter node doesn't have an input."); + } + + auto input_rel = GetRelationFromDeclaration(declaration, ext_set_); + + filter_rel->set_allocated_input(input_rel->release()); + + ARROW_ASSIGN_OR_RAISE(auto subs_expr, + ToProto(filter_node_options.filter_expression, ext_set_)); + *filter_rel->mutable_condition() = *subs_expr.get(); + + rel_->set_allocated_filter(filter_rel.release()); + + return Status::OK(); + } + + Status operator()(const compute::Declaration& declaration) { + return AddRelation(declaration); + } + + private: + Result> GetRelationFromDeclaration( + const compute::Declaration declaration, ExtensionSet* ext_set) { + auto declr_input = declaration.inputs[0]; + // TODO: figure out a better way + if (util::get_if(&declr_input)) { + return Status::NotImplemented("Only support Plans written in Declaration format."); + } + return ToProto(util::get(declr_input), ext_set); + } + substrait::Rel* rel_; + ExtensionSet* ext_set_; +}; + +} // namespace + +Result> ToProto(const compute::Declaration& declaration, + ExtensionSet* ext_set) { + auto out = internal::make_unique(); + RETURN_NOT_OK(ExtractRelation(out.get(), ext_set)(declaration)); + return std::move(out); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 77d47c586b4..c40ecd87390 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -33,5 +33,9 @@ namespace engine { ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&); +ARROW_ENGINE_EXPORT +Result> ToProto(const compute::Declaration&, + ExtensionSet*); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 2012f1fc26a..89792bb459d 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,6 +52,13 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } +Result> SerializeRelation(const compute::Declaration& declaration, + ExtensionSet* ext_set) { + ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set)); + std::string serialized = relation->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + Result DeserializeRelation(const Buffer& buf, const ExtensionSet& ext_set) { ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 4af9f89ac87..c5d0eaa51be 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -52,9 +52,9 @@ ARROW_ENGINE_EXPORT Result> DeserializePlans( const Buffer& buf, const ConsumerFactory& consumer_factory, ExtensionSet* ext_set_out = NULLPTR); -Result DeserializePlan(const Buffer& buf, - const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out = NULLPTR); +ARROW_ENGINE_EXPORT Result DeserializePlan( + const Buffer& buf, const ConsumerFactory& consumer_factory, + ExtensionSet* ext_set_out = NULLPTR); /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// @@ -122,6 +122,16 @@ ARROW_ENGINE_EXPORT Result> SerializeExpression(const compute::Expression& expr, ExtensionSet* ext_set); +/// \brief Serializes an Arrow compute Declaration to a Substrait Relation message +/// +/// \param[in] declaration the Arrow compute declaration to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// mappings for the components in the used declaration +/// \return a buffer containing the protobuf serialization of the corresponding Substrait +/// Relation message +ARROW_ENGINE_EXPORT Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set); + /// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 9a0e93fc7a0..2d5da991567 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -25,8 +25,10 @@ #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" +#include "arrow/dataset/file_parquet.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" +#include "arrow/filesystem/localfs.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/util/key_value_metadata.h" @@ -1173,5 +1175,293 @@ TEST(Substrait, JoinPlanInvalidKeys) { &ext_set)); } +TEST(Substrait, SerializeRelation) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#else + ExtensionSet ext_set; + auto dummy_schema = schema({field("foo", binary())}); + // creating a dummy dataset using a dummy table + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + ASSERT_OK_AND_ASSIGN(std::string dir_string, + arrow::internal::GetEnvVar("PARQUET_TEST_DATA")); + auto file_name = + arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet"); + + std::vector files; + const std::vector f_paths = {file_name->ToString()}; + + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + } + + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + std::move(filesystem), std::move(files), + std::move(format), {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto options = std::make_shared(); + options->projection = compute::project({}, {}); + auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; + + auto scan_declaration = compute::Declaration({"scan", scan_node_options}); + + ASSERT_OK_AND_ASSIGN(auto serialized_rel, + SerializeRelation(scan_declaration, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto deserialized_decl, + DeserializeRelation(*serialized_rel, ext_set)); + + auto dataset_comparator = [](std::shared_ptr ds_lhs, + std::shared_ptr ds_rhs) -> bool { + const auto& fsd_lhs = checked_cast(*ds_lhs); + const auto& fsd_rhs = checked_cast(*ds_lhs); + const auto& files_lhs = fsd_lhs.files(); + const auto& files_rhs = fsd_rhs.files(); + + if (files_lhs.size() != files_rhs.size()) { + return false; + } + uint64_t fidx = 0; + for (const auto& l_file : files_lhs) { + if (l_file != files_rhs[fidx]) { + return false; + } + fidx++; + } + bool cmp_file_format = fsd_lhs.format()->Equals(*fsd_lhs.format()); + bool cmp_file_system = fsd_lhs.filesystem()->Equals(fsd_rhs.filesystem()); + return cmp_file_format && cmp_file_system; + }; + + auto scan_option_comparator = [dataset_comparator]( + const dataset::ScanNodeOptions& lhs, + const dataset::ScanNodeOptions& rhs) -> bool { + bool cmp_rso = lhs.require_sequenced_output == rhs.require_sequenced_output; + bool cmp_ds = dataset_comparator(lhs.dataset, rhs.dataset); + return cmp_rso && cmp_ds; + }; + + EXPECT_EQ(deserialized_decl.factory_name, scan_declaration.factory_name); + const auto& lhs = + checked_cast(*deserialized_decl.options); + const auto& rhs = + checked_cast(*scan_declaration.options); + ASSERT_TRUE(scan_option_comparator(lhs, rhs)); +#endif +} + +TEST(Substrait, SerializeRelationEndToEnd) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#else + ExtensionSet ext_set; + compute::ExecContext exec_context; + + ASSERT_OK_AND_ASSIGN(std::string dir_string, + arrow::internal::GetEnvVar("PARQUET_TEST_DATA")); + auto file_name = + arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet"); + + auto dummy_schema = schema({field("foo", binary())}); + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + std::vector files; + const std::string f_path = file_name->ToString(); + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + std::move(filesystem), std::move(files), + std::move(format), {})); + + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto options = std::make_shared(); + options->projection = compute::project({}, {}); + + auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; + + arrow::AsyncGenerator > sink_gen; + + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + + auto scan_declaration = compute::Declaration({"scan", scan_node_options}); + auto sink_declaration = compute::Declaration({"sink", sink_node_options}); + + auto declarations = + compute::Declaration::Sequence({scan_declaration, sink_declaration}); + + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN(auto decl, declarations.AddToPlan(plan.get())); + + ASSERT_OK(decl->Validate()); + + std::shared_ptr sink_reader = compute::MakeGeneratorReader( + dummy_schema, std::move(sink_gen), exec_context.memory_pool()); + + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + std::shared_ptr response_table; + + ASSERT_OK_AND_ASSIGN(response_table, + arrow::Table::FromRecordBatchReader(sink_reader.get())); + + ASSERT_OK_AND_ASSIGN(auto serialized_rel, + SerializeRelation(scan_declaration, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto deserialized_decl, + DeserializeRelation(*serialized_rel, ext_set)); + + arrow::AsyncGenerator > des_sink_gen; + auto des_sink_node_options = compute::SinkNodeOptions{&des_sink_gen}; + + auto des_sink_declaration = compute::Declaration({"sink", des_sink_node_options}); + + auto t_decls = + compute::Declaration::Sequence({deserialized_decl, des_sink_declaration}); + + ASSERT_OK_AND_ASSIGN(auto t_plan, compute::ExecPlan::Make()); + ASSERT_OK_AND_ASSIGN(auto t_decl, t_decls.AddToPlan(t_plan.get())); + + ASSERT_OK(t_decl->Validate()); + + std::shared_ptr des_sink_reader = + compute::MakeGeneratorReader(dummy_schema, std::move(des_sink_gen), + exec_context.memory_pool()); + + ASSERT_OK(t_plan->Validate()); + ASSERT_OK(t_plan->StartProducing()); + + std::shared_ptr des_response_table; + + ASSERT_OK_AND_ASSIGN(des_response_table, + arrow::Table::FromRecordBatchReader(des_sink_reader.get())); + + ASSERT_TRUE(response_table->Equals(*des_response_table, true)); +#endif +} + +TEST(Substrait, SerializeFilterRelation) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#else + ExtensionSet ext_set; + compute::ExecContext exec_context; + + ASSERT_OK_AND_ASSIGN(std::string dir_string, + arrow::internal::GetEnvVar("PARQUET_TEST_DATA")); + auto file_name = arrow::internal::PlatformFilename::FromString(dir_string) + ->Join("alltypes_plain.parquet"); + + // Note: left the timestamp field since it is not supported. + // Add it back once it is added. + auto dummy_schema = schema({ + field("id", int32()), + field("bool_col", boolean()), + field("tinyint_col", int32()), + field("smallint_col", int32()), + field("int_col", int32()), + field("bigint_col", int64()), + field("float_col", float32()), + field("double_col", float64()), + field("date_string_col", binary()), + field("string_col", binary()), + }); + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + + std::vector files; + const std::string f_path = file_name->ToString(); + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + std::move(filesystem), std::move(files), + std::move(format), {})); + + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto options = std::make_shared(); + + options->projection = compute::project({}, {}); + + auto scan_node_options = dataset::ScanNodeOptions{dataset, options}; + + arrow::AsyncGenerator > sink_gen; + + auto sink_node_options = compute::SinkNodeOptions{&sink_gen}; + + compute::Expression filter_expr = + compute::equal(compute::field_ref("bigint_col"), compute::literal(10)); + // TODO: evaluate this + const std::shared_ptr kBoringSchema = schema({field("bigint_col", int32())}); + ASSERT_OK_AND_ASSIGN(filter_expr, filter_expr.Bind(*kBoringSchema)); + auto filter_node_options = compute::FilterNodeOptions{{filter_expr}}; + + auto scan_declaration = compute::Declaration({"scan", scan_node_options}); + auto filter_declaration = compute::Declaration({"filter", filter_node_options}); + auto sink_declaration = compute::Declaration({"sink", sink_node_options}); + + auto declarations = compute::Declaration::Sequence( + {scan_declaration, filter_declaration, sink_declaration}); + + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make(&exec_context)); + ASSERT_OK_AND_ASSIGN(auto decl, declarations.AddToPlan(plan.get())); + + ASSERT_OK(decl->Validate()); + + auto out_schema = schema({field("bigint_col", int64())}); + std::shared_ptr sink_reader = compute::MakeGeneratorReader( + out_schema, std::move(sink_gen), exec_context.memory_pool()); + + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + std::shared_ptr response_table; + + ASSERT_OK_AND_ASSIGN(response_table, + arrow::Table::FromRecordBatchReader(sink_reader.get())); + + auto scan_filter_declr = + compute::Declaration::Sequence({scan_declaration, filter_declaration}); + + ASSERT_OK_AND_ASSIGN(auto serialized_filter_rel, + SerializeRelation(scan_filter_declr, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto deserialized_filter_decl, + DeserializeRelation(*serialized_filter_rel, ext_set)); + + arrow::AsyncGenerator > des_sink_gen; + auto des_sink_node_options = compute::SinkNodeOptions{&des_sink_gen}; + + auto des_sink_declaration = compute::Declaration({"sink", des_sink_node_options}); + + auto t_decls = + compute::Declaration::Sequence({deserialized_filter_decl, des_sink_declaration}); + + ASSERT_OK_AND_ASSIGN(auto t_plan, compute::ExecPlan::Make()); + ASSERT_OK_AND_ASSIGN(auto t_decl, t_decls.AddToPlan(t_plan.get())); + + ASSERT_OK(t_decl->Validate()); + + std::shared_ptr des_sink_reader = + compute::MakeGeneratorReader(out_schema, std::move(des_sink_gen), + exec_context.memory_pool()); + + ASSERT_OK(t_plan->Validate()); + ASSERT_OK(t_plan->StartProducing()); + + std::shared_ptr des_response_table; + + ASSERT_OK_AND_ASSIGN(des_response_table, + arrow::Table::FromRecordBatchReader(des_sink_reader.get())); + + ASSERT_TRUE(response_table->Equals(*des_response_table, true)); +#endif +} + } // namespace engine } // namespace arrow