diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 6e8522897ee..0e1f5ebc664 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -698,6 +698,20 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic }; } +ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrait_fn_id) { + return + [substrait_fn_id](const compute::Expression::Call& call) -> Result { + // nullable=true isn't quite correct but we don't know the nullability of + // the inputs + SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), + /*nullable=*/true); + for (std::size_t i = 0; i < call.arguments.size(); i++) { + substrait_call.SetValueArg(static_cast(i), call.arguments[i]); + } + return std::move(substrait_call); + }; +} + ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( const std::string& function_name, uint32_t max_args) { return [function_name, @@ -873,6 +887,11 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { AddArrowToSubstraitCall(std::string(fn_name) + "_checked", EncodeOptionlessOverflowableArithmetic(fn_id))); } + // Comparison operators + for (const auto& fn_name : {"equal", "is_not_distinct_from"}) { + Id fn_id{kSubstraitComparisonFunctionsUri, fn_name}; + DCHECK_OK(AddArrowToSubstraitCall(fn_name, EncodeOptionlessComparison(fn_id))); + } } }; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index b0fdb9bdc2f..1efd4e1a0a9 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -17,6 +17,8 @@ #include "arrow/engine/substrait/plan_internal.h" +#include "arrow/dataset/plan.h" +#include "arrow/engine/substrait/relation_internal.h" #include "arrow/result.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" @@ -133,5 +135,19 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, registry); } +Result> PlanToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto subs_plan = internal::make_unique(); + auto plan_rel = internal::make_unique(); + auto rel_root = internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(auto rel, ToProto(declr, ext_set, conversion_options)); + rel_root->set_allocated_input(rel.release()); + plan_rel->set_allocated_root(rel_root.release()); + subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); + RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); + return std::move(subs_plan); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index dce23cdceba..e1ced549ce1 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -19,7 +19,9 @@ #pragma once +#include "arrow/compute/exec/exec_plan.h" #include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -51,5 +53,17 @@ Result GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); +/// \brief Serialize a declaration into a substrait::Plan. +/// +/// Note that, this is a part of a roundtripping test API and not +/// designed for use in production +/// \param[in] declr the sequence of declarations to be serialized +/// \param[in, out] ext_set the extension set to be updated +/// \param[in] conversion_options options to control serialization behavior +/// \return the serialized plan +ARROW_ENGINE_EXPORT Result> PlanToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c5c02f51558..c5d212c8c2f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -29,8 +29,16 @@ #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/util_internal.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" namespace arrow { + +using ::arrow::internal::UriFromAbsolutePath; +using internal::checked_cast; +using internal::make_unique; + namespace engine { template @@ -162,36 +170,45 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } path = path.substr(7); - if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath) { - ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); - if (file.type() == fs::FileType::File) { - files.push_back(std::move(file)); - } else if (file.type() == fs::FileType::Directory) { + switch (item.path_type_case()) { + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath: { + ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); + if (file.type() == fs::FileType::File) { + files.push_back(std::move(file)); + } else if (file.type() == fs::FileType::Directory) { + fs::FileSelector selector; + selector.base_dir = path; + selector.recursive = true; + ARROW_ASSIGN_OR_RAISE(auto discovered_files, + filesystem->GetFileInfo(selector)); + std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); + } + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile: { + files.emplace_back(path, fs::FileType::File); + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder: { fs::FileSelector selector; selector.base_dir = path; selector.recursive = true; ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); - std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); + std::move(discovered_files.begin(), discovered_files.end(), + std::back_inserter(files)); + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPathGlob: { + ARROW_ASSIGN_OR_RAISE(auto discovered_files, + fs::internal::GlobFiles(filesystem, path)); + std::move(discovered_files.begin(), discovered_files.end(), + std::back_inserter(files)); + break; + } + default: { + return Status::Invalid("Unrecognized file type in LocalFiles"); } - } - if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile) { - files.emplace_back(path, fs::FileType::File); - } else if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder) { - fs::FileSelector selector; - selector.base_dir = path; - selector.recursive = true; - ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); - std::move(discovered_files.begin(), discovered_files.end(), - std::back_inserter(files)); - } else { - ARROW_ASSIGN_OR_RAISE(auto discovered_files, - fs::internal::GlobFiles(filesystem, path)); - std::move(discovered_files.begin(), discovered_files.end(), - std::back_inserter(files)); } } @@ -421,5 +438,141 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& rel.DebugString()); } +namespace { + +Result> ExtractSchemaToBind(const compute::Declaration& declr) { + std::shared_ptr bind_schema; + if (declr.factory_name == "scan") { + const auto& opts = checked_cast(*(declr.options)); + bind_schema = opts.dataset->schema(); + } else if (declr.factory_name == "filter") { + auto input_declr = util::get(declr.inputs[0]); + ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); + } else if (declr.factory_name == "sink") { + // Note that the sink has no output_schema + return bind_schema; + } else { + return Status::Invalid("Schema extraction failed, unsupported factory ", + declr.factory_name); + } + return bind_schema; +} + +Result> ScanRelationConverter( + const std::shared_ptr& schema, const compute::Declaration& declaration, + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + auto read_rel = make_unique(); + const auto& scan_node_options = + checked_cast(*declaration.options); + auto dataset = + dynamic_cast(scan_node_options.dataset.get()); + if (dataset == nullptr) { + return Status::Invalid( + "Can only convert scan node with FileSystemDataset to a Substrait plan."); + } + // set schema + ARROW_ASSIGN_OR_RAISE(auto named_struct, + ToProto(*dataset->schema(), ext_set, conversion_options)); + read_rel->set_allocated_base_schema(named_struct.release()); + + // set local files + auto read_rel_lfs = make_unique(); + for (const auto& file : dataset->files()) { + auto read_rel_lfs_ffs = make_unique(); + read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); + // set file format + auto format_type_name = dataset->format()->type_name(); + if (format_type_name == "parquet") { + read_rel_lfs_ffs->set_allocated_parquet( + new substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); + } else if (format_type_name == "ipc") { + read_rel_lfs_ffs->set_allocated_arrow( + new substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); + } else if (format_type_name == "orc") { + read_rel_lfs_ffs->set_allocated_orc( + new substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); + } else { + return Status::NotImplemented("Unsupported file type: ", format_type_name); + } + read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); + } + read_rel->set_allocated_local_files(read_rel_lfs.release()); + return std::move(read_rel); +} + +Result> FilterRelationConverter( + const std::shared_ptr& schema, const compute::Declaration& declaration, + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + auto filter_rel = make_unique(); + const auto& filter_node_options = + checked_cast(*(declaration.options)); + + auto filter_expr = filter_node_options.filter_expression; + compute::Expression bound_expression; + if (!filter_expr.IsBound()) { + ARROW_ASSIGN_OR_RAISE(bound_expression, filter_expr.Bind(*schema)); + } + + if (declaration.inputs.size() == 0) { + return Status::Invalid("Filter node doesn't have an input."); + } + + // handling input + auto declr_input = declaration.inputs[0]; + ARROW_ASSIGN_OR_RAISE( + auto input_rel, + ToProto(util::get(declr_input), ext_set, conversion_options)); + filter_rel->set_allocated_input(input_rel.release()); + + ARROW_ASSIGN_OR_RAISE(auto subs_expr, + ToProto(bound_expression, ext_set, conversion_options)); + filter_rel->set_allocated_condition(subs_expr.release()); + return std::move(filter_rel); +} + +} // namespace + +Status SerializeAndCombineRelations(const compute::Declaration& declaration, + ExtensionSet* ext_set, + std::unique_ptr* rel, + const ConversionOptions& conversion_options) { + const auto& factory_name = declaration.factory_name; + ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); + // Note that the sink declaration factory doesn't exist for serialization as + // Substrait doesn't deal with a sink node definition + + if (factory_name == "scan") { + ARROW_ASSIGN_OR_RAISE( + auto read_rel, + ScanRelationConverter(schema, declaration, ext_set, conversion_options)); + (*rel)->set_allocated_read(read_rel.release()); + } else if (factory_name == "filter") { + ARROW_ASSIGN_OR_RAISE( + auto filter_rel, + FilterRelationConverter(schema, declaration, ext_set, conversion_options)); + (*rel)->set_allocated_filter(filter_rel.release()); + } else if (factory_name == "sink") { + // Generally when a plan is deserialized the declaration will be a sink declaration. + // Since there is no Sink relation in substrait, this function would be recursively + // called on the input of the Sink declaration. + auto sink_input_decl = util::get(declaration.inputs[0]); + RETURN_NOT_OK( + SerializeAndCombineRelations(sink_input_decl, ext_set, rel, conversion_options)); + } else { + return Status::NotImplemented("Factory ", factory_name, + " not implemented for roundtripping."); + } + + return Status::OK(); +} + +Result> ToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto rel = make_unique(); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); + return std::move(rel); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 3699d1f6577..778d1e5bc01 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -40,9 +40,19 @@ struct DeclarationInfo { int num_columns; }; +/// \brief Convert a Substrait Rel object to an Acero declaration ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); +/// \brief Convert an Acero Declaration to a Substrait Rel +/// +/// Note that, in order to provide a generic interface for ToProto, +/// the ExecNode or ExecPlan are not used in this context as Declaration +/// is preferred in the Substrait space rather than internal components of +/// Acero execution engine. +ARROW_ENGINE_EXPORT Result> ToProto( + const compute::Declaration&, ExtensionSet*, const ConversionOptions&); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 9f7d979e2f0..c6297675492 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,6 +52,23 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } +Result> SerializePlan( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto subs_plan, + PlanToProto(declaration, ext_set, conversion_options)); + std::string serialized = subs_plan->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set, conversion_options)); + std::string serialized = relation->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + Result DeserializeRelation( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 6c2083fb56a..2a14ca67570 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -36,6 +36,19 @@ namespace arrow { namespace engine { +/// \brief Serialize an Acero Plan to a binary protobuf Substrait message +/// +/// \param[in] declaration the Acero declaration to serialize. +/// This declaration is the sink relation of the Acero plan. +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// \param[in] conversion_options options to control how the conversion is done +/// +/// \return a buffer containing the protobuf serialization of the Acero relation +ARROW_ENGINE_EXPORT +Result> SerializePlan( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); + /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. using ConsumerFactory = std::function()>; @@ -202,6 +215,17 @@ Result> SerializeExpression( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); +/// \brief Serialize an Acero Declaration to a binary protobuf Substrait message +/// +/// \param[in] declaration the Acero declaration to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// \param[in] conversion_options options to control how the conversion is done +/// +/// \return a buffer containing the protobuf serialization of the Acero relation +ARROW_ENGINE_EXPORT Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); + /// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 04405b31680..9b6c3f715f7 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -23,17 +23,30 @@ #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_ipc.h" +#include "arrow/dataset/file_parquet.h" + #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" #include "arrow/engine/substrait/serde.h" + #include "arrow/engine/substrait/util.h" + +#include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/test_util.h" +#include "arrow/io/compressed.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/writer.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/util/key_value_metadata.h" +#include "parquet/arrow/writer.h" + +#include "arrow/util/hash_util.h" +#include "arrow/util/hashing.h" + using testing::ElementsAre; using testing::Eq; using testing::HasSubstr; @@ -42,9 +55,46 @@ using testing::UnorderedElementsAre; namespace arrow { using internal::checked_cast; - +using internal::hash_combine; namespace engine { +Status WriteIpcData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr input) { + EXPECT_OK_AND_ASSIGN(auto mmap, file_system->OpenOutputStream(path)); + ARROW_ASSIGN_OR_RAISE( + auto file_writer, + MakeFileWriter(mmap, input->schema(), ipc::IpcWriteOptions::Defaults())); + TableBatchReader reader(input); + std::shared_ptr batch; + while (true) { + RETURN_NOT_OK(reader.ReadNext(&batch)); + if (batch == nullptr) { + break; + } + RETURN_NOT_OK(file_writer->WriteRecordBatch(*batch)); + } + RETURN_NOT_OK(file_writer->Close()); + return Status::OK(); +} + +Result> GetTableFromPlan( + compute::Declaration& declarations, + arrow::AsyncGenerator>& sink_gen, + compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); + ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); + + RETURN_NOT_OK(decl->Validate()); + + std::shared_ptr sink_reader = compute::MakeGeneratorReader( + output_schema, std::move(sink_gen), exec_context.memory_pool()); + + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + return arrow::Table::FromRecordBatchReader(sink_reader.get()); +} + class NullSinkNodeConsumer : public compute::SinkNodeConsumer { public: Status Init(const std::shared_ptr&, compute::BackpressureControl*) override { @@ -866,6 +916,7 @@ Result GetSubstraitJSON() { auto file_name = arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet"); auto file_path = file_name->ToString(); + std::string substrait_json = R"({ "relations": [ {"rel": { @@ -1814,5 +1865,243 @@ TEST(Substrait, AggregateBadPhase) { ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); } +TEST(Substrait, BasicPlanRoundTripping) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + arrow::dataset::internal::Initialize(); + + auto dummy_schema = schema( + {field("key", int32()), field("shared", int32()), field("distinct", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])", + R"([ + [0, 2, 1], + [1, 3, 2], + [4, 1, 3], + [3, 1, 3], + [1, 2, 5] + ])", + R"([ + [2, 2, 12], + [5, 3, 12], + [1, 3, 12] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + const std::string file_name = "serde_test.arrow"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait-tempdir-")); + std::cout << "file_path_str " << tempdir->path().ToString() << std::endl; + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + + std::vector files; + const std::vector f_paths = {file_path_str}; + + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + } + + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto scan_options = std::make_shared(); + scan_options->projection = compute::project({}, {}); + const std::string filter_col_left = "shared"; + const std::string filter_col_right = "distinct"; + auto comp_left_value = compute::field_ref(filter_col_left); + auto comp_right_value = compute::field_ref(filter_col_right); + auto filter = compute::equal(comp_left_value, comp_right_value); + + arrow::AsyncGenerator> sink_gen; + + auto declarations = compute::Declaration::Sequence( + {compute::Declaration( + {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), + compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); + } +} + +TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + arrow::dataset::internal::Initialize(); + + auto dummy_schema = schema( + {field("key", int32()), field("shared", int32()), field("distinct", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 4] + ])", + R"([ + [0, 2, 1], + [1, 3, 2], + [4, 1, 1], + [3, 1, 3], + [1, 2, 2] + ])", + R"([ + [2, 2, 12], + [5, 3, 12], + [1, 3, 3] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + const std::string file_name = "serde_test.arrow"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait-tempdir-")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + + std::vector files; + const std::vector f_paths = {file_path_str}; + + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + } + + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto scan_options = std::make_shared(); + scan_options->projection = compute::project({}, {}); + const std::string filter_col_left = "shared"; + const std::string filter_col_right = "distinct"; + auto comp_left_value = compute::field_ref(filter_col_left); + auto comp_right_value = compute::field_ref(filter_col_right); + auto filter = compute::equal(comp_left_value, comp_right_value); + + arrow::AsyncGenerator> sink_gen; + + auto declarations = compute::Declaration::Sequence( + {compute::Declaration( + {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), + compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + + ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(declarations, sink_gen, + exec_context, dummy_schema)); + + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); + } + arrow::AsyncGenerator> rnd_trp_sink_gen; + auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; + auto rnd_trp_sink_declaration = + compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); + auto rnd_trp_declarations = + compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, + GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, + exec_context, dummy_schema)); + EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/filesystem/localfs_test.cc b/cpp/src/arrow/filesystem/localfs_test.cc index 0078a593938..fd36faf30fa 100644 --- a/cpp/src/arrow/filesystem/localfs_test.cc +++ b/cpp/src/arrow/filesystem/localfs_test.cc @@ -32,6 +32,7 @@ #include "arrow/filesystem/util_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/io_util.h" +#include "arrow/util/uri.h" namespace arrow { namespace fs { @@ -40,6 +41,7 @@ namespace internal { using ::arrow::internal::FileDescriptor; using ::arrow::internal::PlatformFilename; using ::arrow::internal::TemporaryDir; +using ::arrow::internal::UriFromAbsolutePath; class LocalFSTestMixin : public ::testing::Test { public: @@ -173,16 +175,6 @@ class TestLocalFS : public LocalFSTestMixin { fs_ = std::make_shared(local_path_, local_fs_); } - std::string UriFromAbsolutePath(const std::string& path) { -#ifdef _WIN32 - // Path is supposed to start with "X:/..." - return "file:///" + path; -#else - // Path is supposed to start with "/..." - return "file://" + path; -#endif - } - template void CheckFileSystemFromUriFunc(const std::string& uri, FileSystemFromUriFunc&& fs_from_uri) { @@ -307,7 +299,7 @@ TYPED_TEST(TestLocalFS, NormalizePathThroughSubtreeFS) { TYPED_TEST(TestLocalFS, FileSystemFromUriFile) { // Concrete test with actual file - const auto uri_string = this->UriFromAbsolutePath(this->local_path_); + const auto uri_string = UriFromAbsolutePath(this->local_path_); this->TestFileSystemFromUri(uri_string); this->TestFileSystemFromUriOrPath(uri_string); diff --git a/cpp/src/arrow/filesystem/util_internal.cc b/cpp/src/arrow/filesystem/util_internal.cc index 0d2ad709026..e6f301bdbf1 100644 --- a/cpp/src/arrow/filesystem/util_internal.cc +++ b/cpp/src/arrow/filesystem/util_internal.cc @@ -78,6 +78,7 @@ Status InvalidDeleteDirContents(util::string_view path) { Result GlobFiles(const std::shared_ptr& filesystem, const std::string& glob) { + // TODO: ARROW-17640 // The candidate entries at the current depth level. // We start with the filesystem root. FileInfoVector results{FileInfo("", FileType::Directory)}; diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index 11ae80d03e2..a62040f3a70 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -1867,7 +1867,9 @@ Result> TemporaryDir::Make(const std::string& pref [&](const NativePathString& base_dir) -> Result> { Status st; for (int attempt = 0; attempt < 3; ++attempt) { - PlatformFilename fn(base_dir + kNativeSep + base_name + kNativeSep); + PlatformFilename fn_base_dir(base_dir); + PlatformFilename fn_base_name(base_name + kNativeSep); + PlatformFilename fn = fn_base_dir.Join(fn_base_name); auto result = CreateDir(fn); if (!result.ok()) { // Probably a permissions error or a non-existing base_dir diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc index 7a8484ce51a..abfc9de8b49 100644 --- a/cpp/src/arrow/util/uri.cc +++ b/cpp/src/arrow/util/uri.cc @@ -304,5 +304,15 @@ Status Uri::Parse(const std::string& uri_string) { return Status::OK(); } +std::string UriFromAbsolutePath(const std::string& path) { +#ifdef _WIN32 + // Path is supposed to start with "X:/..." + return "file:///" + path; +#else + // Path is supposed to start with "/..." + return "file://" + path; +#endif +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h index eae1956eafc..50d9eccf82f 100644 --- a/cpp/src/arrow/util/uri.h +++ b/cpp/src/arrow/util/uri.h @@ -104,5 +104,9 @@ std::string UriEncodeHost(const std::string& host); ARROW_EXPORT bool IsValidUriScheme(const arrow::util::string_view s); +/// Create a file uri from a given absolute path +ARROW_EXPORT +std::string UriFromAbsolutePath(const std::string& path); + } // namespace internal } // namespace arrow