From 716a5b944a8ba1bef057be677a5ab3bb7f5eea6d Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 11 Mar 2022 10:10:58 -0500 Subject: [PATCH 01/19] Substrait integrations --- .../compute/kernels/scalar_arithmetic.cc | 142 ++++++++++++++++++ .../arrow/engine/substrait/extension_set.cc | 11 +- .../engine/substrait/relation_internal.cc | 11 +- cpp/src/arrow/engine/substrait/serde.cc | 14 +- 4 files changed, 171 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2d543e23266..db6dda5b1a6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -495,6 +495,102 @@ struct DivideChecked { } }; +// if at least one argument is NaN, returns the first one that is NaN +struct Minimum { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left < right ? left : right; + } +}; + +// if both arguments are NaN, returns the first one +struct MinimumChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left < right ? 
left : right; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left < right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left < right ? left : right; + } +}; + +// if at least one argument is NaN, returns the first one that is NaN +struct Maximum { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left > right ? left : right; + } +}; + +// if both arguments are NaN, returns the first one +struct MaximumChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left > right ? left : right; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left > right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left > right ? 
left : right; + } +}; + struct Negate { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { @@ -2372,6 +2468,30 @@ const FunctionDoc negate_checked_doc{ "doesn't fail on overflow, use function \"negate\"."), {"x"}}; +const FunctionDoc min_doc{"Take the minimum of the arguments element-wise", + ("Results will take the first NaN if at least one is input.\n" + "Use function \"minimum_checked\" if you want a number\n" + "to be returned if one is input."), + {"x", "y"}}; + +const FunctionDoc min_checked_doc{ + "Take the minimum of the argumentss element-wise", + ("This function returns a number if one is input. For a variant that\n" + "returns a NaN if one is input, use function \"minimum\"."), + {"x"}}; + +const FunctionDoc max_doc{"Take the maximum of the arguments element-wise", + ("Results will take the first NaN if at least one is input.\n" + "Use function \"maximum_checked\" if you want a number\n" + "to be returned if one is input."), + {"x", "y"}}; + +const FunctionDoc max_checked_doc{ + "Take the maximum of the argumentss element-wise", + ("This function returns a number if one is input. For a variant that\n" + "returns a NaN if one is input, use function \"maximum\"."), + {"x"}}; + const FunctionDoc pow_doc{ "Raise arguments to power element-wise", ("Integer to negative integer power returns an error. 
However, integer overflow\n" @@ -2874,6 +2994,28 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { AddDecimalUnaryKernels(negate_checked.get()); DCHECK_OK(registry->AddFunction(std::move(negate_checked))); + // ---------------------------------------------------------------------- + auto minimum = MakeArithmeticFunction("minimum", &min_doc); + AddDecimalBinaryKernels("minimum", minimum.get()); + DCHECK_OK(registry->AddFunction(std::move(minimum))); + + // ---------------------------------------------------------------------- + auto minimum_checked = MakeArithmeticFunctionNotNull( + "minimum_checked", &min_checked_doc); + AddDecimalBinaryKernels("minimum_checked", minimum_checked.get()); + DCHECK_OK(registry->AddFunction(std::move(minimum_checked))); + + // ---------------------------------------------------------------------- + auto maximum = MakeArithmeticFunction("maximum", &max_doc); + AddDecimalBinaryKernels("maximum", maximum.get()); + DCHECK_OK(registry->AddFunction(std::move(maximum))); + + // ---------------------------------------------------------------------- + auto maximum_checked = MakeArithmeticFunctionNotNull( + "maximum_checked", &max_checked_doc); + AddDecimalBinaryKernels("maximum_checked", maximum_checked.get()); + DCHECK_OK(registry->AddFunction(std::move(maximum_checked))); + // ---------------------------------------------------------------------- auto power = MakeArithmeticFunction( "power", &pow_doc); diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index fe43ab28799..b2d75faa38c 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -246,10 +246,15 @@ ExtensionIdRegistry* default_extension_id_registry() { // all functions (and prototypes) that Arrow provides that are relevant // for Substrait, and include mappings for all of them here. See // ARROW-15535. 
- for (util::string_view name : { - "add", + for (std::pair name_pair : { + std::make_pair("add", "add"), + std::make_pair("divide", "divide"), + std::make_pair("power", "power"), + std::make_pair("clip_lower", "maximum"), + std::make_pair("clip_upper", "minimum"), }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + DCHECK_OK(RegisterFunction( + {kArrowExtTypesUri, name_pair.first}, name_pair.second.to_string())); } } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 4ef19349a8d..4e4a4f51265 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -19,6 +19,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" +#include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" @@ -85,7 +86,7 @@ Result FromProto(const substrait::Rel& rel, "substrait::ReadRel::LocalFiles::advanced_extension"); } - auto format = std::make_shared(); + std::shared_ptr format; auto filesystem = std::make_shared(); std::vector> fragments; @@ -97,8 +98,14 @@ Result FromProto(const substrait::Rel& rel, "path_type other than uri_file"); } - if (item.format() != + if (item.format() == substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) { + format = std::make_shared(); + } else if (util::string_view{item.uri_file()}.ends_with(".arrow")) { + format = std::make_shared(); + } else if (util::string_view{item.uri_file()}.ends_with(".feather")) { + format = std::make_shared(); + } else { return Status::NotImplemented( "substrait::ReadRel::LocalFiles::FileOrFiles::format " "other than FILE_FORMAT_PARQUET"); diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index ea916d86757..9e6e64910bd 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ 
b/cpp/src/arrow/engine/substrait/serde.cc @@ -67,10 +67,20 @@ Result> DeserializePlan( std::vector sink_decls; for (const substrait::PlanRel& plan_rel : plan.relations()) { + ARROW_ASSIGN_OR_RAISE( + auto decl, + FromProto(plan_rel.has_root() ? plan_rel.root().input() : plan_rel.rel(), + ext_set)); if (plan_rel.has_root()) { - return Status::NotImplemented("substrait::PlanRel with custom output field names"); + compute::ProjectNodeOptions* options_with_names = + dynamic_cast(decl.options.get()); + if (options_with_names == nullptr) { + return Status::NotImplemented( + "substrait::PlanRel with custom output field names"); + } + auto names = plan_rel.root().names(); + options_with_names->names = {names.begin(), names.end()}; } - ARROW_ASSIGN_OR_RAISE(auto decl, FromProto(plan_rel.rel(), ext_set)); // pipe each relation into a consuming_sink node auto sink_decl = compute::Declaration::Sequence({ From 8327b6778116a31f674f349bc99843d4fc541861 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 11 Mar 2022 10:10:58 -0500 Subject: [PATCH 02/19] Substrait integrations --- .../compute/kernels/scalar_arithmetic.cc | 142 ++++++++++++++++++ .../arrow/engine/substrait/extension_set.cc | 11 +- cpp/src/arrow/engine/substrait/serde.cc | 14 +- 3 files changed, 162 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2d543e23266..db6dda5b1a6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -495,6 +495,102 @@ struct DivideChecked { } }; +// if at least one argument is NaN, returns the first one that is NaN +struct Minimum { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::isnan(left) ? left : left < right ? 
left : right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left < right ? left : right; + } +}; + +// if both arguments are NaN, returns the first one +struct MinimumChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left < right ? left : right; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left < right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left < right ? left : right; + } +}; + +// if at least one argument is NaN, returns the first one that is NaN +struct Maximum { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left > right ? 
left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left > right ? left : right; + } +}; + +// if both arguments are NaN, returns the first one +struct MaximumChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left > right ? left : right; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left > right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left > right ? left : right; + } +}; + struct Negate { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { @@ -2372,6 +2468,30 @@ const FunctionDoc negate_checked_doc{ "doesn't fail on overflow, use function \"negate\"."), {"x"}}; +const FunctionDoc min_doc{"Take the minimum of the arguments element-wise", + ("Results will take the first NaN if at least one is input.\n" + "Use function \"minimum_checked\" if you want a number\n" + "to be returned if one is input."), + {"x", "y"}}; + +const FunctionDoc min_checked_doc{ + "Take the minimum of the argumentss element-wise", + ("This function returns a number if one is input. 
For a variant that\n" + "returns a NaN if one is input, use function \"minimum\"."), + {"x"}}; + +const FunctionDoc max_doc{"Take the maximum of the arguments element-wise", + ("Results will take the first NaN if at least one is input.\n" + "Use function \"maximum_checked\" if you want a number\n" + "to be returned if one is input."), + {"x", "y"}}; + +const FunctionDoc max_checked_doc{ + "Take the maximum of the argumentss element-wise", + ("This function returns a number if one is input. For a variant that\n" + "returns a NaN if one is input, use function \"maximum\"."), + {"x"}}; + const FunctionDoc pow_doc{ "Raise arguments to power element-wise", ("Integer to negative integer power returns an error. However, integer overflow\n" @@ -2874,6 +2994,28 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { AddDecimalUnaryKernels(negate_checked.get()); DCHECK_OK(registry->AddFunction(std::move(negate_checked))); + // ---------------------------------------------------------------------- + auto minimum = MakeArithmeticFunction("minimum", &min_doc); + AddDecimalBinaryKernels("minimum", minimum.get()); + DCHECK_OK(registry->AddFunction(std::move(minimum))); + + // ---------------------------------------------------------------------- + auto minimum_checked = MakeArithmeticFunctionNotNull( + "minimum_checked", &min_checked_doc); + AddDecimalBinaryKernels("minimum_checked", minimum_checked.get()); + DCHECK_OK(registry->AddFunction(std::move(minimum_checked))); + + // ---------------------------------------------------------------------- + auto maximum = MakeArithmeticFunction("maximum", &max_doc); + AddDecimalBinaryKernels("maximum", maximum.get()); + DCHECK_OK(registry->AddFunction(std::move(maximum))); + + // ---------------------------------------------------------------------- + auto maximum_checked = MakeArithmeticFunctionNotNull( + "maximum_checked", &max_checked_doc); + AddDecimalBinaryKernels("maximum_checked", maximum_checked.get()); + 
DCHECK_OK(registry->AddFunction(std::move(maximum_checked))); + // ---------------------------------------------------------------------- auto power = MakeArithmeticFunction( "power", &pow_doc); diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index fe43ab28799..b2d75faa38c 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -246,10 +246,15 @@ ExtensionIdRegistry* default_extension_id_registry() { // all functions (and prototypes) that Arrow provides that are relevant // for Substrait, and include mappings for all of them here. See // ARROW-15535. - for (util::string_view name : { - "add", + for (std::pair name_pair : { + std::make_pair("add", "add"), + std::make_pair("divide", "divide"), + std::make_pair("power", "power"), + std::make_pair("clip_lower", "maximum"), + std::make_pair("clip_upper", "minimum"), }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + DCHECK_OK(RegisterFunction( + {kArrowExtTypesUri, name_pair.first}, name_pair.second.to_string())); } } diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index ea916d86757..9e6e64910bd 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -67,10 +67,20 @@ Result> DeserializePlan( std::vector sink_decls; for (const substrait::PlanRel& plan_rel : plan.relations()) { + ARROW_ASSIGN_OR_RAISE( + auto decl, + FromProto(plan_rel.has_root() ? 
plan_rel.root().input() : plan_rel.rel(), + ext_set)); if (plan_rel.has_root()) { - return Status::NotImplemented("substrait::PlanRel with custom output field names"); + compute::ProjectNodeOptions* options_with_names = + dynamic_cast(decl.options.get()); + if (options_with_names == nullptr) { + return Status::NotImplemented( + "substrait::PlanRel with custom output field names"); + } + auto names = plan_rel.root().names(); + options_with_names->names = {names.begin(), names.end()}; } - ARROW_ASSIGN_OR_RAISE(auto decl, FromProto(plan_rel.rel(), ext_set)); // pipe each relation into a consuming_sink node auto sink_decl = compute::Declaration::Sequence({ From 330ae6639aa5ae01dc4b9c39e3930d069b1bf04c Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 25 Mar 2022 09:16:32 -0400 Subject: [PATCH 03/19] Added end-to-end Substrait-to-Arrow enhancements --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 3 +- .../compute/kernels/scalar_arithmetic.cc | 18 +- cpp/src/arrow/dataset/file_base.cc | 181 +++++++++++++++--- .../engine/substrait/expression_internal.cc | 17 +- .../engine/substrait/expression_internal.h | 3 +- .../arrow/engine/substrait/extension_set.cc | 5 +- .../engine/substrait/relation_internal.cc | 153 ++++++++++++++- .../engine/substrait/relation_internal.h | 5 +- cpp/src/arrow/engine/substrait/serde.cc | 18 +- 9 files changed, 344 insertions(+), 59 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 35bd80be3e0..afa8ab8d2cd 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1636,13 +1636,12 @@ macro(build_substrait) message("Building Substrait from source") set(SUBSTRAIT_PROTOS + algebra capabilities - expression extensions/extensions function parameterized_types plan - relations type type_expressions) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 
db6dda5b1a6..bf2d82d4448 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -1869,6 +1869,17 @@ Result ResolveDecimalDivisionOutput(KernelContext*, }); } +Result ResolveDecimalMinimumOrMaximumOutput( + KernelContext*, const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + DCHECK_EQ(s1, s2); + const int32_t scale = s1; + const int32_t precision = std::max(p1, p2); + return std::make_pair(precision, scale); + }); +} + Result ResolveTemporalOutput(KernelContext*, const std::vector& args) { DCHECK_EQ(args[0].type->id(), args[1].type->id()); @@ -1907,7 +1918,10 @@ void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) { out_type = OutputType(ResolveDecimalMultiplicationOutput); } else if (op == "divide") { out_type = OutputType(ResolveDecimalDivisionOutput); + } else if (op == "minimum" || op == "maximum") { + out_type = OutputType(ResolveDecimalMinimumOrMaximumOutput); } else { + ARROW_LOG(FATAL) << "AddDecimalBinaryKernels failed: name is " << name; DCHECK(false); } @@ -2478,7 +2492,7 @@ const FunctionDoc min_checked_doc{ "Take the minimum of the argumentss element-wise", ("This function returns a number if one is input. For a variant that\n" "returns a NaN if one is input, use function \"minimum\"."), - {"x"}}; + {"x", "y"}}; const FunctionDoc max_doc{"Take the maximum of the arguments element-wise", ("Results will take the first NaN if at least one is input.\n" @@ -2490,7 +2504,7 @@ const FunctionDoc max_checked_doc{ "Take the maximum of the argumentss element-wise", ("This function returns a number if one is input. 
For a variant that\n" "returns a NaN if one is input, use function \"maximum\"."), - {"x"}}; + {"x", "y"}}; const FunctionDoc pow_doc{ "Raise arguments to power element-wise", diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index f4551c27590..04080a17a68 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -42,6 +42,7 @@ #include "arrow/util/map.h" #include "arrow/util/string.h" #include "arrow/util/task_group.h" +#include "arrow/util/tracing_internal.h" #include "arrow/util/variant.h" namespace arrow { @@ -269,6 +270,35 @@ Future<> FileWriter::Finish() { namespace { +Status WriteBatch( + std::shared_ptr batch, + compute::Expression guarantee, + FileSystemDatasetWriteOptions write_options, + std::function, std::string)> write) { + ARROW_ASSIGN_OR_RAISE(auto groups, write_options.partitioning->Partition(batch)); + batch.reset(); // drop to hopefully conserve memory + + if (write_options.max_partitions <= 0) { + return Status::Invalid("max_partitions must be positive (was ", + write_options.max_partitions, ")"); + } + + if (groups.batches.size() > static_cast(write_options.max_partitions)) { + return Status::Invalid("Fragment would be written into ", groups.batches.size(), + " partitions. 
This exceeds the maximum of ", + write_options.max_partitions); + } + + for (std::size_t index = 0; index < groups.batches.size(); index++) { + auto partition_expression = and_(groups.expressions[index], guarantee); + auto next_batch = groups.batches[index]; + ARROW_ASSIGN_OR_RAISE(std::string destination, + write_options.partitioning->Format(partition_expression)); + RETURN_NOT_OK(write(next_batch, destination)); + } + return Status::OK(); +} + class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { public: DatasetWritingSinkNodeConsumer(std::shared_ptr schema, @@ -294,35 +324,20 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { private: Status WriteNextBatch(std::shared_ptr batch, compute::Expression guarantee) { - ARROW_ASSIGN_OR_RAISE(auto groups, write_options_.partitioning->Partition(batch)); - batch.reset(); // drop to hopefully conserve memory - - if (write_options_.max_partitions <= 0) { - return Status::Invalid("max_partitions must be positive (was ", - write_options_.max_partitions, ")"); - } - - if (groups.batches.size() > static_cast(write_options_.max_partitions)) { - return Status::Invalid("Fragment would be written into ", groups.batches.size(), - " partitions. 
This exceeds the maximum of ", - write_options_.max_partitions); - } - - for (std::size_t index = 0; index < groups.batches.size(); index++) { - auto partition_expression = and_(groups.expressions[index], guarantee); - auto next_batch = groups.batches[index]; - ARROW_ASSIGN_OR_RAISE(std::string destination, - write_options_.partitioning->Format(partition_expression)); - RETURN_NOT_OK(task_group_.AddTask([this, next_batch, destination] { - Future<> has_room = dataset_writer_->WriteRecordBatch(next_batch, destination); - if (!has_room.is_finished() && backpressure_toggle_) { - backpressure_toggle_->Close(); - return has_room.Then([this] { backpressure_toggle_->Open(); }); - } - return has_room; - })); - } - return Status::OK(); + return WriteBatch(batch, + guarantee, + write_options_, + [this](std::shared_ptr next_batch, + std::string destination) { + return task_group_.AddTask([this, next_batch, destination] { + Future<> has_room = dataset_writer_->WriteRecordBatch(next_batch, destination); + if (!has_room.is_finished() && backpressure_toggle_) { + backpressure_toggle_->Close(); + return has_room.Then([this] { backpressure_toggle_->Open(); }); + } + return has_room; + }); + }); } std::shared_ptr schema_; @@ -399,9 +414,117 @@ Result MakeWriteNode(compute::ExecPlan* plan, return node; } +namespace { + +class TeeNode : public compute::MapNode { + public: + TeeNode(compute::ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema, + std::unique_ptr dataset_writer, + FileSystemDatasetWriteOptions write_options, + std::shared_ptr backpressure_toggle, + bool async_mode) + : MapNode(plan, std::move(inputs), std::move(output_schema), async_mode), + dataset_writer_(std::move(dataset_writer)), + write_options_(std::move(write_options)), + backpressure_toggle_(std::move(backpressure_toggle)) { + } + + static Result Make(compute::ExecPlan* plan, + std::vector inputs, + const compute::ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, 
"TeeNode")); + + const WriteNodeOptions write_node_options = + checked_cast(options); + const FileSystemDatasetWriteOptions& write_options = + write_node_options.write_options; + const std::shared_ptr& schema = + write_node_options.schema + ? write_node_options.schema + : inputs[0]->output_schema(); + const std::shared_ptr& backpressure_toggle = + write_node_options.backpressure_toggle; + + if (schema.get() != inputs[0]->output_schema().get()) { + return Status::Invalid("input schema does not match one in options"); + } + + ARROW_ASSIGN_OR_RAISE(auto dataset_writer, + internal::DatasetWriter::Make(write_options)); + + return plan->EmplaceNode(plan, std::move(inputs), std::move(schema), + std::move(dataset_writer), + std::move(write_options), + std::move(backpressure_toggle), + /*async_mode=*/true); + } + + const char* kind_name() const override { return "TeeNode"; } + + Result DoTee(const compute::ExecBatch& batch) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr record_batch, + batch.ToRecordBatch(output_schema())); + ARROW_RETURN_NOT_OK(WriteNextBatch(std::move(record_batch), batch.guarantee)); + return batch; + } + + Status WriteNextBatch(std::shared_ptr batch, + compute::Expression guarantee) { + return WriteBatch(batch, + guarantee, + write_options_, + [this](std::shared_ptr next_batch, + std::string destination) { + return task_group_.AddTask([this, next_batch, destination] { + util::tracing::Span span; + START_SPAN(span, "Tee", + {{"tee.base_dir", ToStringExtra()}, + {"tee.length", next_batch.length}}); + Future<> has_room = dataset_writer_->WriteRecordBatch(next_batch, destination); + if (!has_room.is_finished() && backpressure_toggle_) { + backpressure_toggle_->Close(); + return has_room.Then([this] { backpressure_toggle_->Open(); }); + } + return has_room; + }); + }); + } + + void InputReceived(compute::ExecNode* input, compute::ExecBatch batch) override { + EVENT(span_, "InputReceived", {{"batch.length", batch.length}}); + DCHECK_EQ(input, inputs_[0]); + auto 
func = [this](compute::ExecBatch batch) { + util::tracing::Span span; + START_SPAN_WITH_PARENT(span, span_, "InputReceived", + {{"tee", ToStringExtra()}, + {"node.label", label()}, + {"batch.length", batch.length}}); + auto result = DoTee(std::move(batch)); + MARK_SPAN(span, result.status()); + END_SPAN(span); + return result; + }; + this->SubmitTask(std::move(func), std::move(batch)); + } + + protected: + std::string ToStringExtra(int indent = 0) const override { + return "base_dir=" + write_options_.base_dir; + } + + private: + std::unique_ptr dataset_writer_; + FileSystemDatasetWriteOptions write_options_; + std::shared_ptr backpressure_toggle_; +}; + +} // namespace + namespace internal { void InitializeDatasetWriter(arrow::compute::ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("write", MakeWriteNode)); + DCHECK_OK(registry->AddFactory("tee", TeeNode::Make)); } } // namespace internal diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index d18ae4dcb41..db67a8ae5d6 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -165,7 +165,22 @@ Result FromProto(const substrait::Expression& expr, ARROW_ASSIGN_OR_RAISE(arguments[i], FromProto(scalar_fn.args(i), ext_set)); } - return compute::call(decoded_function.name.to_string(), std::move(arguments)); + auto func_name = decoded_function.name.to_string(); + if (func_name != "cast") { + return compute::call(func_name, std::move(arguments)); + } else { + switch (scalar_fn.output_type().kind_case()) { + case substrait::Type::kTimestamp: // fall through + case substrait::Type::kTimestampTz: { + auto cast_options = + compute::CastOptions::Safe(arrow::timestamp(TimeUnit::NANO, "utc")); + return compute::call(func_name, std::move(arguments), cast_options); + } + + default: + break; + } + } } default: diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h 
b/cpp/src/arrow/engine/substrait/expression_internal.h index e491fa674cf..67d53725433 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -26,7 +26,8 @@ #include "arrow/engine/visibility.h" #include "arrow/type_fwd.h" -#include "substrait/expression.pb.h" // IWYU pragma: export +#include "substrait/algebra.pb.h" // IWYU pragma: export +#include "substrait/type_expressions.pb.h" // IWYU pragma: export namespace arrow { namespace engine { diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index b2d75faa38c..587de5e923e 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -164,7 +164,7 @@ Result ExtensionSet::Make(std::vector uris, set.functions_[i] = {rec->id, rec->function_name}; continue; } - return Status::Invalid("Function ", function_ids[i].uri, "#", type_ids[i].name, + return Status::Invalid("Function ", function_ids[i].uri, "#", function_ids[i].name, " not found"); } @@ -252,6 +252,9 @@ ExtensionIdRegistry* default_extension_id_registry() { std::make_pair("power", "power"), std::make_pair("clip_lower", "maximum"), std::make_pair("clip_upper", "minimum"), + std::make_pair("equals", "equal"), + std::make_pair("cast", "cast"), + std::make_pair("negate", "negate"), }) { DCHECK_OK(RegisterFunction( {kArrowExtTypesUri, name_pair.first}, name_pair.second.to_string())); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 4e4a4f51265..1faaf157bfc 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -19,6 +19,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" +#include "arrow/dataset/file_base.h" #include "arrow/dataset/file_ipc.h" #include "arrow/dataset/file_parquet.h" #include 
"arrow/dataset/plan.h" @@ -26,10 +27,35 @@ #include "arrow/engine/substrait/expression_internal.h" #include "arrow/engine/substrait/type_internal.h" #include "arrow/filesystem/localfs.h" +#include "arrow/filesystem/path_util.h" namespace arrow { namespace engine { +static std::shared_ptr<::arrow::dataset::Partitioning> EmptyPartitioning() { + class EmptyPartitioning : public ::arrow::dataset::Partitioning { + public: + EmptyPartitioning() : ::arrow::dataset::Partitioning(::arrow::schema({})) {} + + std::string type_name() const override { return "empty"; } + + Result Parse(const std::string& path) const override { + return compute::literal(true); + } + + Result Format(const compute::Expression& expr) const override { + return ""; + } + + Result Partition( + const std::shared_ptr& batch) const override { + return PartitionedBatches{{batch}, {compute::literal(true)}}; + } + }; + + return std::make_shared(); +} + template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { @@ -49,8 +75,9 @@ Status CheckRelCommon(const RelMessage& rel) { return Status::OK(); } -Result FromProto(const substrait::Rel& rel, - const ExtensionSet& ext_set) { +Result FromProtoInternal(const substrait::Rel& rel, + const ExtensionSet& ext_set, + std::vector& names) { static bool dataset_init = false; if (!dataset_init) { dataset_init = true; @@ -98,20 +125,22 @@ Result FromProto(const substrait::Rel& rel, "path_type other than uri_file"); } + util::string_view uri_file{item.uri_file()}; + if (item.format() == substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) { format = std::make_shared(); - } else if (util::string_view{item.uri_file()}.ends_with(".arrow")) { + } else if (uri_file.ends_with(".arrow")) { format = std::make_shared(); - } else if (util::string_view{item.uri_file()}.ends_with(".feather")) { + } else if (uri_file.ends_with(".feather")) { format = std::make_shared(); } else { return Status::NotImplemented( 
"substrait::ReadRel::LocalFiles::FileOrFiles::format " - "other than FILE_FORMAT_PARQUET"); + "other than FILE_FORMAT_PARQUET and not recognized"); } - if (!util::string_view{item.uri_file()}.starts_with("file:///")) { + if (!uri_file.starts_with("file:///")) { return Status::NotImplemented( "substrait::ReadRel::LocalFiles::FileOrFiles::uri_file " "with other than local filesystem (file:///)"); @@ -147,6 +176,95 @@ Result FromProto(const substrait::Rel& rel, "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}; } + case substrait::Rel::RelTypeCase::kWrite: { + const auto& write = rel.write(); + RETURN_NOT_OK(CheckRelCommon(write)); + + if (!write.has_input()) { + return Status::Invalid("substrait::WriteRel with no input relation"); + } + ARROW_ASSIGN_OR_RAISE(auto input, FromProto(write.input(), ext_set, names)); + + if (!write.has_local_files()) { + return Status::NotImplemented( + "substrait::WriteRel with write_type other than LocalFiles"); + } + + if (write.local_files().has_advanced_extension()) { + return Status::NotImplemented( + "substrait::WriteRel::LocalFiles::advanced_extension"); + } + + std::shared_ptr format; + auto filesystem = std::make_shared(); + + if (write.local_files().items().size() != 1) { + return Status::NotImplemented( + "substrait::WriteRel with non-single LocalFiles items"); + } + + dataset::FileSystemDatasetWriteOptions write_options; + write_options.filesystem = filesystem; + write_options.partitioning = EmptyPartitioning(); + + for (const auto& item : write.local_files().items()) { + if (item.path_type_case() != + substrait::WriteRel_LocalFiles_FileOrFiles::kUriFile) { + return Status::NotImplemented( + "substrait::WriteRel::LocalFiles::FileOrFiles with " + "path_type other than uri_file"); + } + + util::string_view uri_file{item.uri_file()}; + + if (item.format() == + substrait::WriteRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) { + format = std::make_shared(); + } else if (uri_file.ends_with(".arrow")) { + 
format = std::make_shared(); + } else if (uri_file.ends_with(".feather")) { + format = std::make_shared(); + } else { + return Status::NotImplemented( + "substrait::WriteRel::LocalFiles::FileOrFiles::format " + "other than FILE_FORMAT_PARQUET and not recognized"); + } + write_options.file_write_options = format->DefaultWriteOptions(); + + if (!uri_file.starts_with("file:///")) { + return Status::NotImplemented( + "substrait::WriteRel::LocalFiles::FileOrFiles::uri_file " + "with other than local filesystem (file:///)"); + } + auto path = item.uri_file().substr(7); + + if (item.partition_index() != 0) { + return Status::NotImplemented( + "non-default " + "substrait::WriteRel::LocalFiles::FileOrFiles::partition_index"); + } + + if (item.start_row() != 0) { + return Status::NotImplemented( + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start_row"); + } + + if (item.number_of_rows() != 0) { + return Status::NotImplemented( + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::number_of_rows"); + } + + auto path_pair = fs::internal::GetAbstractPathParent(path); + write_options.basename_template = path_pair.second; + write_options.base_dir = path_pair.first; + } + + return compute::Declaration::Sequence({ + std::move(input), + {"tee", dataset::WriteNodeOptions{std::move(write_options), nullptr}}, + }); + } + case substrait::Rel::RelTypeCase::kFilter: { const auto& filter = rel.filter(); RETURN_NOT_OK(CheckRelCommon(filter)); @@ -154,7 +272,7 @@ Result FromProto(const substrait::Rel& rel, if (!filter.has_input()) { return Status::Invalid("substrait::FilterRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set, names)); if (!filter.has_condition()) { return Status::Invalid("substrait::FilterRel with no condition expression"); @@ -174,17 +292,27 @@ Result FromProto(const substrait::Rel& rel, if (!project.has_input()) { return 
Status::Invalid("substrait::ProjectRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProto(project.input(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto input, + FromProtoInternal(project.input(), ext_set, names)); + size_t expr_size = static_cast(project.expressions_size()); + auto names_begin = names.end() - std::min(expr_size, names.size()); + auto names_iter = names_begin; + std::vector project_names; std::vector expressions; for (const auto& expr : project.expressions()) { expressions.emplace_back(); ARROW_ASSIGN_OR_RAISE(expressions.back(), FromProto(expr, ext_set)); + project_names.push_back( + names_iter != names.end() ? *names_iter++ : expressions.back().ToString()); } + names.erase(names_begin, names.end()); return compute::Declaration::Sequence({ std::move(input), - {"project", compute::ProjectNodeOptions{std::move(expressions)}}, + {"project", + compute::ProjectNodeOptions{std::move(expressions), std::move(project_names)} + }, }); } @@ -197,5 +325,12 @@ Result FromProto(const substrait::Rel& rel, rel.DebugString()); } +Result FromProto(const substrait::Rel& rel, + const ExtensionSet& ext_set, + std::vector names) { + std::vector copy_names(names.begin(), names.end()); + return FromProtoInternal(rel, ext_set, copy_names); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index d9b90f50779..2462dd4d9ed 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -25,13 +25,14 @@ #include "arrow/engine/visibility.h" #include "arrow/type_fwd.h" -#include "substrait/relations.pb.h" // IWYU pragma: export +#include "substrait/algebra.pb.h" // IWYU pragma: export namespace arrow { namespace engine { ARROW_ENGINE_EXPORT -Result FromProto(const substrait::Rel&, const ExtensionSet&); +Result FromProto(const substrait::Rel&, const ExtensionSet&, + std::vector = {}); } // 
namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 9e6e64910bd..91818836853 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -67,20 +67,14 @@ Result> DeserializePlan( std::vector sink_decls; for (const substrait::PlanRel& plan_rel : plan.relations()) { - ARROW_ASSIGN_OR_RAISE( - auto decl, - FromProto(plan_rel.has_root() ? plan_rel.root().input() : plan_rel.rel(), - ext_set)); + const substrait::Rel& rel = + plan_rel.has_root() ? plan_rel.root().input() : plan_rel.rel(); + std::vector names; if (plan_rel.has_root()) { - compute::ProjectNodeOptions* options_with_names = - dynamic_cast(decl.options.get()); - if (options_with_names == nullptr) { - return Status::NotImplemented( - "substrait::PlanRel with custom output field names"); - } - auto names = plan_rel.root().names(); - options_with_names->names = {names.begin(), names.end()}; + const auto& root_names = plan_rel.root().names(); + names.assign(root_names.begin(), root_names.end()); } + ARROW_ASSIGN_OR_RAISE(auto decl, FromProto(rel, ext_set, names)); // pipe each relation into a consuming_sink node auto sink_decl = compute::Declaration::Sequence({ From abee905903b6901232a7cab5827c3e908d3e7f13 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 27 Mar 2022 10:27:06 -0400 Subject: [PATCH 04/19] Added logical comparison operators to Substrait registry --- cpp/src/arrow/engine/substrait/extension_set.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 587de5e923e..c74066ab140 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -253,6 +253,11 @@ ExtensionIdRegistry* default_extension_id_registry() { std::make_pair("clip_lower", "maximum"), std::make_pair("clip_upper", "minimum"), 
std::make_pair("equals", "equal"), + std::make_pair("not_equals", "not_equal"), + std::make_pair("less", "less"), + std::make_pair("greater", "greater"), + std::make_pair("less_equal", "less_equal"), + std::make_pair("greater_equal", "greater_equal"), std::make_pair("cast", "cast"), std::make_pair("negate", "negate"), }) { From 3f3f3eff05b7789c655cf5abf8dd64abfd1d79f4 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 31 Mar 2022 07:54:45 -0400 Subject: [PATCH 05/19] Added as-of-merge execution --- cpp/src/arrow/compute/exec/options.h | 17 ++++ .../engine/substrait/expression_internal.cc | 15 +--- .../engine/substrait/relation_internal.cc | 77 ++++++++++++++----- .../engine/substrait/relation_internal.h | 2 +- .../arrow/engine/substrait/type_internal.cc | 2 +- 5 files changed, 81 insertions(+), 32 deletions(-) diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index d2ad45d37b9..26554885124 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -310,5 +310,22 @@ class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions { std::shared_ptr output_schema; }; +/// \addtogroup execnode-options +/// @{ + +/// \brief Make a node which implements as-of-merge (v1) operation. 
+class ARROW_EXPORT AsOfMergeV1NodeOptions : public ExecNodeOptions { + public: + AsOfMergeV1NodeOptions(std::string key_column, std::string time_column, + int64_t tolerance) + : key_column(key_column), time_column(time_column), tolerance(tolerance) {} + + std::string key_column; + std::string time_column; + int64_t tolerance; +}; + +/// @} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index db67a8ae5d6..5f4a7e3e163 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -169,17 +169,10 @@ Result FromProto(const substrait::Expression& expr, if (func_name != "cast") { return compute::call(func_name, std::move(arguments)); } else { - switch (scalar_fn.output_type().kind_case()) { - case substrait::Type::kTimestamp: // fall through - case substrait::Type::kTimestampTz: { - auto cast_options = - compute::CastOptions::Safe(arrow::timestamp(TimeUnit::NANO, "utc")); - return compute::call(func_name, std::move(arguments), cast_options); - } - - default: - break; - } + ARROW_ASSIGN_OR_RAISE(auto output_type_desc, + FromProto(scalar_fn.output_type(), ext_set)); + auto cast_options = compute::CastOptions::Safe(output_type_desc.first); + return compute::call(func_name, std::move(arguments), cast_options); } } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 1faaf157bfc..60043245a75 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -75,9 +75,11 @@ Status CheckRelCommon(const RelMessage& rel) { return Status::OK(); } -Result FromProtoInternal(const substrait::Rel& rel, - const ExtensionSet& ext_set, - std::vector& names) { +Result FromProtoInternal( + const substrait::Rel& rel, + const ExtensionSet& ext_set, + 
std::vector::const_iterator& names_begin, + std::vector::const_iterator& names_end) { static bool dataset_init = false; if (!dataset_init) { dataset_init = true; @@ -183,7 +185,8 @@ Result FromProtoInternal(const substrait::Rel& rel, if (!write.has_input()) { return Status::Invalid("substrait::WriteRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProto(write.input(), ext_set, names)); + ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(write.input(), ext_set, + names_begin, names_end)); if (!write.has_local_files()) { return Status::NotImplemented( @@ -254,14 +257,14 @@ Result FromProtoInternal(const substrait::Rel& rel, "non-default substrait::ReadRel::LocalFiles::FileOrFiles::number_of_rows"); } - auto path_pair = fs::internal::GetAbstractPathParent(path); + auto path_pair = fs::internal::GetAbstractPathParent(path); write_options.basename_template = path_pair.second; - write_options.base_dir = path_pair.first; + write_options.base_dir = path_pair.first; } return compute::Declaration::Sequence({ std::move(input), - {"tee", dataset::WriteNodeOptions{std::move(write_options), nullptr}}, + {"tee", dataset::WriteNodeOptions{std::move(write_options), nullptr}}, }); } @@ -272,7 +275,8 @@ Result FromProtoInternal(const substrait::Rel& rel, if (!filter.has_input()) { return Status::Invalid("substrait::FilterRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set, names)); + ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(filter.input(), ext_set, + names_begin, names_end)); if (!filter.has_condition()) { return Status::Invalid("substrait::FilterRel with no condition expression"); @@ -292,30 +296,64 @@ Result FromProtoInternal(const substrait::Rel& rel, if (!project.has_input()) { return Status::Invalid("substrait::ProjectRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, - FromProtoInternal(project.input(), ext_set, names)); + ARROW_ASSIGN_OR_RAISE(auto input, 
FromProtoInternal(project.input(), ext_set, + names_begin, names_end)); - size_t expr_size = static_cast(project.expressions_size()); - auto names_begin = names.end() - std::min(expr_size, names.size()); - auto names_iter = names_begin; + auto expr_size = + static_cast(project.expressions_size()); + auto names_mid = names_end - std::min(expr_size, names_end - names_begin); + auto names_iter = names_mid; std::vector project_names; std::vector expressions; for (const auto& expr : project.expressions()) { expressions.emplace_back(); ARROW_ASSIGN_OR_RAISE(expressions.back(), FromProto(expr, ext_set)); project_names.push_back( - names_iter != names.end() ? *names_iter++ : expressions.back().ToString()); + names_iter != names_end ? *names_iter++ : expressions.back().ToString()); } - names.erase(names_begin, names.end()); + names_end = names_mid; return compute::Declaration::Sequence({ std::move(input), {"project", compute::ProjectNodeOptions{std::move(expressions), std::move(project_names)} - }, + }, }); } + case substrait::Rel::RelTypeCase::kAsOfMerge: { + const auto& as_of_merge = rel.as_of_merge(); + RETURN_NOT_OK(CheckRelCommon(as_of_merge)); + + auto inputs_size = as_of_merge.inputs_size(); + if (inputs_size < 2) { + return Status::Invalid("substrait::AsOfMergeRel with fewer than 2 inputs"); + } + if (inputs_size > 6) { + return Status::Invalid("substrait::AsOfMergeRel with more than 6 inputs"); + } + if (as_of_merge.version_case() != substrait::AsOfMergeRel::VersionCase::kV1) { + return Status::Invalid("substrait::AsOfMergeRel with unsupported version"); + } + std::vector inputs; + inputs.reserve(inputs_size); + for (auto input_rel : as_of_merge.inputs()) { + ARROW_ASSIGN_OR_RAISE(auto decl, FromProtoInternal(input_rel, ext_set, + names_begin, names_end)); + auto input = compute::Declaration::Input(decl); + inputs.push_back(input); + } + return compute::Declaration{ + "as_of_merge", + inputs, + compute::AsOfMergeV1NodeOptions{ + as_of_merge.v1().key_column(), + 
as_of_merge.v1().time_column(), + as_of_merge.v1().tolerance(), + } + }; + } + default: break; } @@ -327,9 +365,10 @@ Result FromProtoInternal(const substrait::Rel& rel, Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set, - std::vector names) { - std::vector copy_names(names.begin(), names.end()); - return FromProtoInternal(rel, ext_set, copy_names); + const std::vector& names) { + std::vector::const_iterator names_begin = names.begin(); + std::vector::const_iterator names_end = names.end(); + return FromProtoInternal(rel, ext_set, names_begin, names_end); } } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 2462dd4d9ed..ebc6a100c0d 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -32,7 +32,7 @@ namespace engine { ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, - std::vector = {}); + const std::vector& = {}); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index c1dac97b682..d425f85a983 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -127,7 +127,7 @@ Result, bool>> FromProto( return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); case substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, - TimestampTzTimezoneString()); + type.timestamp_tz().tz()); case substrait::Type::kDate: return FromProtoImpl(type.date()); From 98d2663c7c3e3d100992115a875935513cf5d1cc Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 8 Apr 2022 06:03:28 -0400 Subject: [PATCH 06/19] Added Substrait deserialization of flat field references for AsOfMerge --- cpp/src/arrow/array/builder_binary.h | 4 + cpp/src/arrow/compute/exec/options.h | 11 +- 
.../compute/kernels/scalar_compare_test.cc | 46 +++++++ .../arrow/engine/substrait/extension_set.cc | 1 + .../engine/substrait/relation_internal.cc | 128 +++++++++++++----- 5 files changed, 153 insertions(+), 37 deletions(-) diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index 703355bf278..95470b542fd 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -424,6 +424,10 @@ class ARROW_EXPORT StringBuilder : public BinaryBuilder { Status Finish(std::shared_ptr* out) { return FinishTyped(out); } std::shared_ptr type() const override { return utf8(); } + + util::string_view operator[](int64_t i) const { + return GetView(i); + } }; /// \class LargeBinaryBuilder diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 26554885124..bc08630b791 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -316,12 +316,15 @@ class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions { /// \brief Make a node which implements as-of-merge (v1) operation. 
class ARROW_EXPORT AsOfMergeV1NodeOptions : public ExecNodeOptions { public: - AsOfMergeV1NodeOptions(std::string key_column, std::string time_column, + AsOfMergeV1NodeOptions(std::vector key_fields, + std::vector time_fields, int64_t tolerance) - : key_column(key_column), time_column(time_column), tolerance(tolerance) {} + : key_fields(key_fields), + time_fields(time_fields), + tolerance(tolerance) {} - std::string key_column; - std::string time_column; + std::vector key_fields; + std::vector time_fields; int64_t tolerance; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index f0f2d7e3679..64ebf0b44f7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -502,6 +502,52 @@ TEST(TestCompareTimestamps, DifferentParameters) { } } +TEST(TestCompareTimestamps, ScalarArray) { + const char* scalar_json = R"("1970-01-02")"; + const char* array_json = R"(["1970-01-02","2000-02-01","1900-02-28"])"; + + auto CheckArrayCase = [&](std::shared_ptr scalar_type, + std::shared_ptr array_type, + CompareOperator op, const char* expected_json) { + auto lhs = ScalarFromJSON(scalar_type, scalar_json); + auto rhs = ArrayFromJSON(array_type, array_json); + auto expected = ArrayFromJSON(boolean(), expected_json); + if (scalar_type->Equals(array_type)) { + ASSERT_OK_AND_ASSIGN(Datum result, + CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs})); + AssertArraysEqual(*expected, *result.make_array(), /*verbose=*/true); + } else { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr( + "Cannot compare timestamp with timezone to timestamp without timezone"), + CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs})); + } + }; + + for (auto unit : {TimeUnit::SECOND, + TimeUnit::MILLI, + TimeUnit::MICRO, + TimeUnit::NANO, + }) { + for (auto types : + std::vector, std::shared_ptr>> { + {timestamp(unit), timestamp(unit)}, + 
{timestamp(unit), timestamp(unit, "utc")}, + {timestamp(unit, "utc"), timestamp(unit)}, + {timestamp(unit, "utc"), timestamp(unit, "utc")}, + }) { + auto t0 = types.first, t1 = types.second; + CheckArrayCase(t0, t1, CompareOperator::EQUAL, "[true, false, false]"); + CheckArrayCase(t0, t1, CompareOperator::NOT_EQUAL, "[false, true, true]"); + CheckArrayCase(t0, t1, CompareOperator::LESS, "[false, true, false]"); + CheckArrayCase(t0, t1, CompareOperator::LESS_EQUAL, "[true, true, false]"); + CheckArrayCase(t0, t1, CompareOperator::GREATER, "[false, false, true]"); + CheckArrayCase(t0, t1, CompareOperator::GREATER_EQUAL, "[true, false, true]"); + } + } +} + template class TestCompareDecimal : public ::testing::Test {}; TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes); diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index c74066ab140..31671569aff 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -247,6 +247,7 @@ ExtensionIdRegistry* default_extension_id_registry() { // for Substrait, and include mappings for all of them here. See // ARROW-15535. for (std::pair name_pair : { + std::make_pair("and", "and"), std::make_pair("add", "add"), std::make_pair("divide", "divide"), std::make_pair("power", "power"), diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 60043245a75..e89cb0c9452 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -14,7 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. 
- +#include #include "arrow/engine/substrait/relation_internal.h" #include "arrow/compute/api_scalar.h" @@ -75,11 +75,72 @@ Status CheckRelCommon(const RelMessage& rel) { return Status::OK(); } +Result FromProto(const substrait::Expression& expr, const std::string& what) { + int32_t index; + switch (expr.rex_type_case()) { + case substrait::Expression::RexTypeCase::kSelection: { + const auto& selection = expr.selection(); + switch (selection.root_type_case()) { + case substrait::Expression_FieldReference::RootTypeCase::kRootReference: { + break; + } + default: { + return Status::NotImplemented(std::string( + "substrait::Expression with non-root-reference for ") + what); + } + } + switch (selection.reference_type_case()) { + case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: { + const auto& direct_reference = selection.direct_reference(); + switch (direct_reference.reference_type_case()) { + case substrait::Expression_ReferenceSegment::ReferenceTypeCase::kStructField: + { + break; + } + default: { + return Status::NotImplemented(std::string( + "substrait::Expression with non-struct-field for ") + what); + } + } + const auto& struct_field = direct_reference.struct_field(); + if (struct_field.has_child()) { + return Status::NotImplemented(std::string( + "substrait::Expression with non-flat struct-field for ") + what); + } + index = struct_field.field(); + break; + } + default: { + return Status::NotImplemented(std::string( + "substrait::Expression with non-direct reference for ") + what); + } + } + break; + } + default: { + return Status::NotImplemented(std::string( + "substrait::AsOfMergeRel with non-selection for ") + what); + } + } + return FieldRef(FieldPath({index})); +} + +Result> FromProto( + const google::protobuf::RepeatedPtrField& exprs, + const std::string& what) { + std::vector fields; + int size = exprs.size(); + for (int i = 0; i < size; i++) { + ARROW_ASSIGN_OR_RAISE( + FieldRef field, FromProto(exprs[i], what)); + 
fields.push_back(field); + } + return fields; +} + Result FromProtoInternal( const substrait::Rel& rel, - const ExtensionSet& ext_set, - std::vector::const_iterator& names_begin, - std::vector::const_iterator& names_end) { + const ExtensionSet& ext_set) { static bool dataset_init = false; if (!dataset_init) { dataset_init = true; @@ -185,8 +246,7 @@ Result FromProtoInternal( if (!write.has_input()) { return Status::Invalid("substrait::WriteRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(write.input(), ext_set, - names_begin, names_end)); + ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(write.input(), ext_set)); if (!write.has_local_files()) { return Status::NotImplemented( @@ -275,8 +335,7 @@ Result FromProtoInternal( if (!filter.has_input()) { return Status::Invalid("substrait::FilterRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(filter.input(), ext_set, - names_begin, names_end)); + ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(filter.input(), ext_set)); if (!filter.has_condition()) { return Status::Invalid("substrait::FilterRel with no condition expression"); @@ -296,27 +355,18 @@ Result FromProtoInternal( if (!project.has_input()) { return Status::Invalid("substrait::ProjectRel with no input relation"); } - ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(project.input(), ext_set, - names_begin, names_end)); - - auto expr_size = - static_cast(project.expressions_size()); - auto names_mid = names_end - std::min(expr_size, names_end - names_begin); - auto names_iter = names_mid; - std::vector project_names; + ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(project.input(), ext_set)); + std::vector expressions; for (const auto& expr : project.expressions()) { expressions.emplace_back(); ARROW_ASSIGN_OR_RAISE(expressions.back(), FromProto(expr, ext_set)); - project_names.push_back( - names_iter != names_end ? 
*names_iter++ : expressions.back().ToString()); } - names_end = names_mid; return compute::Declaration::Sequence({ std::move(input), {"project", - compute::ProjectNodeOptions{std::move(expressions), std::move(project_names)} + compute::ProjectNodeOptions{std::move(expressions)} }, }); } @@ -335,22 +385,25 @@ Result FromProtoInternal( if (as_of_merge.version_case() != substrait::AsOfMergeRel::VersionCase::kV1) { return Status::Invalid("substrait::AsOfMergeRel with unsupported version"); } + + const auto& v1 = as_of_merge.v1(); + ARROW_ASSIGN_OR_RAISE( + auto key_fields, FromProto(v1.key_fields(), "AsOfMerge key field")); + ARROW_ASSIGN_OR_RAISE( + auto time_fields, FromProto(v1.time_fields(), "AsOfMerge time field")); + int64_t tolerance = as_of_merge.v1().tolerance(); + std::vector inputs; inputs.reserve(inputs_size); for (auto input_rel : as_of_merge.inputs()) { - ARROW_ASSIGN_OR_RAISE(auto decl, FromProtoInternal(input_rel, ext_set, - names_begin, names_end)); - auto input = compute::Declaration::Input(decl); - inputs.push_back(input); + ARROW_ASSIGN_OR_RAISE(auto decl, FromProtoInternal(input_rel, ext_set)); + auto input = compute::Declaration::Input(decl); + inputs.push_back(input); } return compute::Declaration{ "as_of_merge", - inputs, - compute::AsOfMergeV1NodeOptions{ - as_of_merge.v1().key_column(), - as_of_merge.v1().time_column(), - as_of_merge.v1().tolerance(), - } + inputs, + compute::AsOfMergeV1NodeOptions{key_fields, time_fields, tolerance} }; } @@ -366,9 +419,18 @@ Result FromProtoInternal( Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set, const std::vector& names) { - std::vector::const_iterator names_begin = names.begin(); - std::vector::const_iterator names_end = names.end(); - return FromProtoInternal(rel, ext_set, names_begin, names_end); + ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(rel, ext_set)); + int names_size = names.size(); + std::vector expressions; + for (int i = 0; i < names_size; i++) { + 
expressions.push_back(compute::field_ref(FieldRef(i))); + } + return compute::Declaration::Sequence({ + std::move(input), + {"project", + compute::ProjectNodeOptions{std::move(expressions), std::move(names)} + }, + }); } } // namespace engine From 5aa7ede884f16595bee86fbb7512eb078e104266 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 17 Apr 2022 05:16:08 -0400 Subject: [PATCH 07/19] Support write-consumer of Arrow Substrait plan --- cpp/src/arrow/dataset/file_base.cc | 20 ++++++++++++- cpp/src/arrow/dataset/partition.h | 20 +++++++++++++ .../engine/substrait/relation_internal.cc | 26 +---------------- cpp/src/arrow/engine/substrait/serde.cc | 28 +++++++++++++++++-- cpp/src/arrow/engine/substrait/serde.h | 7 +++++ 5 files changed, 72 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 04080a17a68..1517fc6edbc 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -310,6 +310,22 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { write_options_(std::move(write_options)), backpressure_toggle_(std::move(backpressure_toggle)) {} + Status Init(compute::ExecNode* node) { + if (node == nullptr) { + return Status::Invalid("internal error - null node"); + } + auto schema = node->inputs()[0]->output_schema(); + if (schema.get() == nullptr) { + return Status::Invalid("internal error - null schema"); + } + if (schema_.get() == nullptr) { + schema_ = schema; + } else if (schema_.get() != schema.get()) { + return Status::Invalid("internal error - inconsistent schemata"); + } + return Status::OK(); + } + Status Consume(compute::ExecBatch batch) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr record_batch, batch.ToRecordBatch(schema_)); @@ -409,7 +425,9 @@ Result MakeWriteNode(compute::ExecPlan* plan, ARROW_ASSIGN_OR_RAISE( auto node, compute::MakeExecNode("consuming_sink", plan, std::move(inputs), - 
compute::ConsumingSinkNodeOptions{std::move(consumer)})); + compute::ConsumingSinkNodeOptions{consumer})); + + ARROW_RETURN_NOT_OK(consumer->Init(node)); return node; } diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index aa6958ed1e8..551b482e9e4 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -90,6 +90,26 @@ class ARROW_DS_EXPORT Partitioning { std::shared_ptr schema_; }; +class ARROW_DS_EXPORT EmptyPartitioning : public Partitioning { +public: + EmptyPartitioning() : Partitioning(::arrow::schema({})) {} + + std::string type_name() const override { return "empty"; } + + Result Parse(const std::string& path) const override { + return compute::literal(true); + } + + Result Format(const compute::Expression& expr) const override { + return ""; + } + + Result Partition( + const std::shared_ptr& batch) const override { + return PartitionedBatches{{batch}, {compute::literal(true)}}; + } +}; + /// \brief The encoding of partition segments. enum class SegmentEncoding : int8_t { /// No encoding. 
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index e89cb0c9452..521ce28babc 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -32,30 +32,6 @@ namespace arrow { namespace engine { -static std::shared_ptr<::arrow::dataset::Partitioning> EmptyPartitioning() { - class EmptyPartitioning : public ::arrow::dataset::Partitioning { - public: - EmptyPartitioning() : ::arrow::dataset::Partitioning(::arrow::schema({})) {} - - std::string type_name() const override { return "empty"; } - - Result Parse(const std::string& path) const override { - return compute::literal(true); - } - - Result Format(const compute::Expression& expr) const override { - return ""; - } - - Result Partition( - const std::shared_ptr& batch) const override { - return PartitionedBatches{{batch}, {compute::literal(true)}}; - } - }; - - return std::make_shared(); -} - template Status CheckRelCommon(const RelMessage& rel) { if (rel.has_common()) { @@ -268,7 +244,7 @@ Result FromProtoInternal( dataset::FileSystemDatasetWriteOptions write_options; write_options.filesystem = filesystem; - write_options.partitioning = EmptyPartitioning(); + write_options.partitioning = std::make_shared(); for (const auto& item : write.local_files().items()) { if (item.path_type_case() != diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 91818836853..69147bc73bf 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -58,8 +58,9 @@ Result DeserializeRelation(const Buffer& buf, return FromProto(rel, ext_set); } -Result> DeserializePlan( - const Buffer& buf, const ConsumerFactory& consumer_factory, +static Result> DeserializePlan( + const Buffer& buf, const std::string& factory_name, + std::function()> options_factory, ExtensionSet* ext_set_out) { ARROW_ASSIGN_OR_RAISE(auto plan, 
ParseFromBuffer(buf)); @@ -79,7 +80,7 @@ Result> DeserializePlan( // pipe each relation into a consuming_sink node auto sink_decl = compute::Declaration::Sequence({ std::move(decl), - {"consuming_sink", compute::ConsumingSinkNodeOptions{consumer_factory()}}, + {factory_name, options_factory()}, }); sink_decls.push_back(std::move(sink_decl)); } @@ -90,6 +91,27 @@ Result> DeserializePlan( return sink_decls; } +Result> DeserializePlan( + const Buffer& buf, const ConsumerFactory& consumer_factory, + ExtensionSet* ext_set_out) { + return DeserializePlan( + buf, + "consuming_sink", + [&consumer_factory]() { + return std::make_shared( + compute::ConsumingSinkNodeOptions{consumer_factory()} + ); + }, + ext_set_out + ); +} + +Result> DeserializePlan( + const Buffer& buf, const WriteOptionsFactory& write_options_factory, + ExtensionSet* ext_set_out) { + return DeserializePlan(buf, "write", write_options_factory, ext_set_out); +} + Result> DeserializeSchema(const Buffer& buf, const ExtensionSet& ext_set) { ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer(buf)); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 9e63a1befb5..4b3e9bf2a24 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -24,6 +24,7 @@ #include #include "arrow/buffer.h" +#include "arrow/dataset/file_base.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/engine/substrait/extension_set.h" @@ -52,6 +53,12 @@ ARROW_ENGINE_EXPORT Result> DeserializePlan( const Buffer& buf, const ConsumerFactory& consumer_factory, ExtensionSet* ext_set = NULLPTR); +using WriteOptionsFactory = std::function()>; + +ARROW_ENGINE_EXPORT Result> DeserializePlan( + const Buffer& buf, const WriteOptionsFactory& write_options_factory, + ExtensionSet* ext_set = NULLPTR); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer 
containing the protobuf serialization of a Substrait Type From c0c0d08cafece80c7ddbd5f04c0640b7dd6afe80 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 28 Apr 2022 10:15:19 -0400 Subject: [PATCH 08/19] Added explanation comment on MakeWriteNode --- cpp/src/arrow/dataset/file_base.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 1517fc6edbc..6b5d711410a 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -427,6 +427,11 @@ Result MakeWriteNode(compute::ExecPlan* plan, compute::MakeExecNode("consuming_sink", plan, std::move(inputs), compute::ConsumingSinkNodeOptions{consumer})); + // this is a workaround specific for Arrow Substrait code paths + // Arrow Substrait creates ExecNodeOptions instances within a Declaration + // at this stage, schemata have not yet been created since nodes haven't + // thus, the ConsumingSinkNodeOptions passed to consumer has a null schema + // the following call to Init fills in the schema using the node just created ARROW_RETURN_NOT_OK(consumer->Init(node)); return node; From f202dc55e1eea75b2e7e4805dfa44be623ee3052 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 28 Apr 2022 10:16:27 -0400 Subject: [PATCH 09/19] Set use_threads on scan options of Arrow Substrait --- cpp/src/arrow/engine/substrait/relation_internal.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 521ce28babc..c765d3e574a 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -131,6 +131,7 @@ Result FromProtoInternal( ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set)); auto scan_options = std::make_shared(); + scan_options->use_threads = true; if (read.has_filter()) { ARROW_ASSIGN_OR_RAISE(scan_options->filter, 
FromProto(read.filter(), ext_set)); From a912ea576f8053d7c327a020ab755a90d67d5f90 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 13 May 2022 03:12:21 -0400 Subject: [PATCH 10/19] try --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 3 +-- cpp/src/arrow/engine/substrait/expression_internal.h | 2 +- cpp/src/arrow/engine/substrait/relation_internal.h | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 32669b2c072..760f06dde90 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1691,13 +1691,12 @@ macro(build_substrait) message("Building Substrait from source") set(SUBSTRAIT_PROTOS + algebra capabilities - expression extensions/extensions function parameterized_types plan - relations type type_expressions) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index 6bbc2d8c767..4e23dc8f708 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -26,7 +26,7 @@ #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" -#include "substrait/expression.pb.h" // IWYU pragma: export +#include "substrait/algebra.pb.h" // IWYU pragma: export namespace arrow { namespace engine { diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 77d47c586b4..ec56a2d3597 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -25,7 +25,7 @@ #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" -#include "substrait/relations.pb.h" // IWYU pragma: export +#include "substrait/algebra.pb.h" // IWYU pragma: export namespace arrow { namespace engine { From dbacb0aa383814ac9ace88bf2f9e759253cdb083 Mon Sep 17 00:00:00 
2001 From: Yaron Gvili Date: Sun, 22 May 2022 04:35:28 -0400 Subject: [PATCH 11/19] integrated and tested --- cpp/src/arrow/compute/exec/options.h | 18 + .../arrow/engine/substrait/extension_set.cc | 361 +++++++++++------- .../arrow/engine/substrait/extension_set.h | 22 +- .../arrow/engine/substrait/plan_internal.cc | 2 +- .../arrow/engine/substrait/plan_internal.h | 2 +- .../engine/substrait/relation_internal.cc | 3 + cpp/src/arrow/engine/substrait/serde.h | 2 +- cpp/src/arrow/engine/substrait/util.cc | 16 + cpp/src/arrow/engine/substrait/util.h | 10 + cpp/src/arrow/python/pyarrow.h | 8 + python/pyarrow/__init__.pxd | 4 +- python/pyarrow/_exec_plan.pxd | 25 ++ python/pyarrow/_exec_plan.pyx | 14 +- python/pyarrow/_substrait.pyx | 54 ++- python/pyarrow/includes/libarrow.pxd | 5 + .../pyarrow/includes/libarrow_substrait.pxd | 7 + python/pyarrow/lib.pxd | 12 + python/pyarrow/public-api.pxi | 20 + python/pyarrow/substrait.py | 3 + python/pyarrow/table.pxi | 14 + 20 files changed, 458 insertions(+), 144 deletions(-) create mode 100644 python/pyarrow/_exec_plan.pxd diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index f5fa639c242..7e057fd9b2f 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -229,6 +229,24 @@ class ARROW_EXPORT SinkNodeConsumer { virtual Future<> Finish() = 0; }; +class ARROW_EXPORT NullSinkNodeConsumer : public SinkNodeConsumer { +public: + virtual Status Init(const std::shared_ptr&, + BackpressureControl*) override { + return Status::OK(); + } + virtual Status Consume(ExecBatch exec_batch) override { + return Status::OK(); + } + virtual Future<> Finish() override { + return Status::OK(); + } +public: + static std::shared_ptr Make() { + return std::make_shared(); + } +}; + /// \brief Add a sink node which consumes data within the exec plan run class ARROW_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions { public: diff --git 
a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index c163b03ce5d..3d4aa93ec9e 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -107,14 +107,14 @@ struct ExtensionSet::Impl { std::unordered_map types_, functions_; }; -ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) +ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {} Result ExtensionSet::Make(std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { ExtensionSet set; set.registry_ = registry; @@ -210,172 +210,263 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } -ExtensionIdRegistry* default_extension_id_registry() { - static struct Impl : ExtensionIdRegistry { - Impl() { - struct TypeName { - std::shared_ptr type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. - for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/true)); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/false)); - } - - // TODO: this is just a placeholder right now. 
We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (std::pair name_pair : { - std::make_pair("and", "and"), - std::make_pair("+", "add"), - std::make_pair("/", "divide"), - std::make_pair("power", "power"), - std::make_pair("clip_lower", "maximum"), - std::make_pair("clip_upper", "minimum"), - std::make_pair("equal", "equal"), - std::make_pair("not_equal", "not_equal"), - std::make_pair("lt", "less"), - std::make_pair("gt", "greater"), - std::make_pair("lte", "less_equal"), - std::make_pair("gte", "greater_equal"), - std::make_pair("cast", "cast"), - std::make_pair("negate", "negate"), - }) { - DCHECK_OK(RegisterFunction( - {kArrowExtTypesUri, name_pair.first}, name_pair.second.to_string())); - } - } +namespace { - std::vector Uris() const override { - return {uris_.begin(), uris_.end()}; +struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + std::vector Uris() const override { + return {uris_.begin(), uris_.end()}; + } + + util::optional GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - util::optional GetType(const DataType& type) const override { - if (auto index = GetIndex(type_to_index_, &type)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + util::optional GetType(Id id, bool is_variation) const override { + if (auto index = GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - util::optional GetType(Id id, bool is_variation) const override { - if (auto index = - GetIndex(is_variation ? 
variation_id_to_index_ : id_to_index_, id)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + virtual Status CanRegisterType(Id id, std::shared_ptr type, + bool is_variation) const { + auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; + if (id_to_index.find(id) != id_to_index.end()) { + return Status::Invalid("Type id was already registered"); } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); + } + return Status::OK(); + } + + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { - DCHECK_EQ(type_ids_.size(), types_.size()); - DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + auto index = static_cast(type_ids_.size()); + + auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; + auto it_success = id_to_index->emplace(copied_id, index); + + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); + } - auto index = static_cast(type_ids_.size()); + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index->erase(it_success.first); + return Status::Invalid("Type was already registered"); + } - auto* id_to_index = is_variation ? 
&variation_id_to_index_ : &id_to_index_; - auto it_success = id_to_index->emplace(copied_id, index); + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + type_is_variation_.push_back(is_variation); + return Status::OK(); + } - if (!it_success.second) { - return Status::Invalid("Type id was already registered"); - } + util::optional GetFunction( + util::string_view arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } - if (!type_to_index_.emplace(type.get(), index).second) { - id_to_index->erase(it_success.first); - return Status::Invalid("Type was already registered"); - } + util::optional GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } - type_ids_.push_back(copied_id); - types_.push_back(std::move(type)); - type_is_variation_.push_back(is_variation); - return Status::OK(); + virtual Status CanRegisterFunction(Id id, std::string arrow_function_name) const { + if (function_id_to_index_.find(id) == function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) == + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); } + return Status::OK(); + } + + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; + + auto index = static_cast(function_ids_.size()); + + auto it_success = function_id_to_index_.emplace(copied_id, 
index); - util::optional GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); } - util::optional GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } + + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set uris_, names_, function_names_; + DataTypeVector types_; + std::vector type_is_variation_; + + // non-owning lookup helpers + std::vector type_ids_, function_ids_; + std::unordered_map id_to_index_, variation_id_to_index_; + std::unordered_map type_to_index_; + + std::vector function_name_ptrs_; + std::unordered_map function_id_to_index_; + std::unordered_map + function_name_to_index_; +}; - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; +struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} - const std::string& copied_function_name{ - 
*function_names_.emplace(std::move(arrow_function_name)).first}; + std::vector Uris() const override { + std::vector uris = parent_->Uris(); + std::unordered_set uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return std::vector(uris); + } + + util::optional GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; + } + return parent_->GetType(type); + } - auto index = static_cast(function_ids_.size()); + util::optional GetType(Id id, bool is_variation) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id, is_variation); + if (type_opt) { + return type_opt; + } + return parent_->GetType(id, is_variation); + } + + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { + return parent_->CanRegisterType(id, type, is_variation) & + ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + } + + util::optional GetFunction( + util::string_view arrow_function_name) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(arrow_function_name); + } - auto it_success = function_id_to_index_.emplace(copied_id, index); + util::optional GetFunction(Id id) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(id); + } - if (!it_success.second) { - return Status::Invalid("Function id was already registered"); - } + Status RegisterFunction(Id id, std::string arrow_function_name) override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + } - if (!function_name_to_index_.emplace(copied_function_name, index).second) { - function_id_to_index_.erase(it_success.first); - return Status::Invalid("Function name was 
already registered"); - } + const ExtensionIdRegistry* parent_; +}; - function_name_ptrs_.push_back(&copied_function_name); - function_ids_.push_back(copied_id); - return Status::OK(); +struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { + DefaultExtensionIdRegistry() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535. + for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/true)); } - // owning storage of uris, names, (arrow::)function_names, types - // note that storing strings like this is safe since references into an - // unordered_set are not invalidated on insertion - std::unordered_set uris_, names_, function_names_; - DataTypeVector types_; - std::vector type_is_variation_; + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/false)); + } - // non-owning lookup helpers - std::vector type_ids_, function_ids_; - std::unordered_map id_to_index_, variation_id_to_index_; - std::unordered_map type_to_index_; + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. 
+ for (std::pair name_pair : { + std::make_pair("add", "add"), + std::make_pair("and", "and"), + std::make_pair("+", "add"), + std::make_pair("/", "divide"), + std::make_pair("power", "power"), + std::make_pair("clip_lower", "maximum"), + std::make_pair("clip_upper", "minimum"), + std::make_pair("equal", "equal"), + std::make_pair("not_equal", "not_equal"), + std::make_pair("lt", "less"), + std::make_pair("gt", "greater"), + std::make_pair("lte", "less_equal"), + std::make_pair("gte", "greater_equal"), + std::make_pair("cast", "cast"), + std::make_pair("negate", "negate"), + }) { + DCHECK_OK(RegisterFunction( + {kArrowExtTypesUri, name_pair.first}, name_pair.second.to_string())); + } + } +}; - std::vector function_name_ptrs_; - std::unordered_map function_id_to_index_; - std::unordered_map - function_name_to_index_; - } impl_; +} // namespace +ExtensionIdRegistry* default_extension_id_registry() { + static DefaultExtensionIdRegistry impl_; return &impl_; } +std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent) { + return std::make_shared(parent); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 951f7ffa3a1..a6019333fc0 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -63,6 +63,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { }; virtual util::optional GetType(const DataType&) const = 0; virtual util::optional GetType(Id, bool is_variation) const = 0; + virtual Status CanRegisterType(Id, std::shared_ptr type, + bool is_variation) const = 0; virtual Status RegisterType(Id, std::shared_ptr, bool is_variation) = 0; /// \brief A mapping between a Substrait ID and an Arrow function @@ -84,6 +86,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { virtual util::optional GetFunction(Id) const = 0; virtual util::optional GetFunction( util::string_view 
arrow_function_name) const = 0; + virtual Status CanRegisterFunction(Id, std::string arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; }; @@ -96,6 +99,19 @@ constexpr util::string_view kArrowExtTypesUri = /// Note: Function support is currently very minimal, see ARROW-15538 ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); +/// \brief Makes a nested registry with a given parent. +/// +/// A nested registry supports registering types and functions other and on top of those +/// already registered in its parent registry. No conflicts in IDs and names used for +/// lookup are allowed. Normally, the given parent is the default registry. +/// +/// One use case for a nested registry is for dynamic registration of functions defined +/// within a Substrait plan while keeping these registrations specific to the plan. When +/// the Substrait plan is disposed of, normally after its execution, the nested registry +/// can be disposed of as well. +ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent); + /// \brief A set of extensions used within a plan /// /// Each time an extension is used within a Substrait plan the extension @@ -140,7 +156,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { }; /// Construct an empty ExtensionSet to be populated during serialization. 
- explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry()); + explicit ExtensionSet(const ExtensionIdRegistry* = default_extension_id_registry()); ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet); /// Construct an ExtensionSet with explicit extension ids for efficient referencing @@ -160,7 +176,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { static Result Make( std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* = default_extension_id_registry()); + const ExtensionIdRegistry* = default_extension_id_registry()); // index in these vectors == value of _anchor/_reference fields /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not @@ -224,7 +240,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { std::size_t num_functions() const { return functions_.size(); } private: - ExtensionIdRegistry* registry_; + const ExtensionIdRegistry* registry_; /// The subset of extension registry URIs referenced by this extension set std::vector uris_; std::vector types_; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 5813dcde24c..8b3d53ae794 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -111,7 +111,7 @@ void SetElement(size_t i, const Element& element, std::vector* vector) { } // namespace Result GetExtensionSetFromPlan(const substrait::Plan& plan, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { std::vector uris; for (const auto& uri : plan.extension_uris()) { SetElement(uri.extension_uri_anchor(), uri.uri(), &uris); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 281cab0c0f3..dce23cdceba 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,7 @@ Status 
AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry()); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 13f7e447b37..0ef9f225e64 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -435,6 +435,9 @@ Result FromProto(const substrait::Rel& rel, const std::vector& names) { ARROW_ASSIGN_OR_RAISE(auto input, FromProtoInternal(rel, ext_set)); int names_size = names.size(); + if (names.size() == 0) { + return input; + } std::vector expressions; for (int i = 0; i < names_size; i++) { expressions.push_back(compute::field_ref(FieldRef(i))); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 8cb61c0d9b2..227c505c0e8 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -59,7 +59,7 @@ Result DeserializePlan(const Buffer& buf, using WriteOptionsFactory = std::function()>; -ARROW_ENGINE_EXPORT Result> DeserializePlan( +ARROW_ENGINE_EXPORT Result> DeserializePlans( const Buffer& buf, const WriteOptionsFactory& write_options_factory, ExtensionSet* ext_set = NULLPTR); diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index bc2aa36856e..74ec8ff33b2 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -123,6 +123,22 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } +Result> DeserializePlans(const Buffer& buffer) { + return engine::DeserializePlans( + buffer, []() { return std::make_shared(); } + ); 
+} + +std::shared_ptr MakeExtensionIdRegistry() { + return nested_extension_id_registry(default_extension_id_registry()); +} + +Status RegisterFunction(ExtensionIdRegistry& registry, const std::string& id_uri, + const std::string& id_name, + const std::string& arrow_function_name) { + return registry.RegisterFunction({id_uri, id_name}, arrow_function_name); +} + } // namespace substrait } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 860a459da2f..c5e150a6cb2 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -37,6 +37,16 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); +ARROW_ENGINE_EXPORT Result> DeserializePlans( + const Buffer& buf); + +ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); + +ARROW_ENGINE_EXPORT Status RegisterFunction(ExtensionIdRegistry& registry, + const std::string& id_uri, + const std::string& id_name, + const std::string& arrow_function_name); + } // namespace substrait } // namespace engine diff --git a/cpp/src/arrow/python/pyarrow.h b/cpp/src/arrow/python/pyarrow.h index 4c365081d70..c52ee7f2ebc 100644 --- a/cpp/src/arrow/python/pyarrow.h +++ b/cpp/src/arrow/python/pyarrow.h @@ -40,6 +40,12 @@ class Status; class Table; class Tensor; +namespace engine { + +class ExtensionIdRegistry; + +} // namespace engine + namespace py { // Returns 0 on success, -1 on error. 
@@ -71,6 +77,8 @@ DECLARE_WRAP_FUNCTIONS(tensor, Tensor) DECLARE_WRAP_FUNCTIONS(batch, RecordBatch) DECLARE_WRAP_FUNCTIONS(table, Table) +DECLARE_WRAP_FUNCTIONS(extension_id_registry, engine::ExtensionIdRegistry) + #undef DECLARE_WRAP_FUNCTIONS namespace internal { diff --git a/python/pyarrow/__init__.pxd b/python/pyarrow/__init__.pxd index 8cc54b4c6bf..2b3b2ed1922 100644 --- a/python/pyarrow/__init__.pxd +++ b/python/pyarrow/__init__.pxd @@ -20,7 +20,7 @@ from pyarrow.includes.libarrow cimport (CArray, CBuffer, CDataType, CField, CRecordBatch, CSchema, CTable, CTensor, CSparseCOOTensor, CSparseCSRMatrix, CSparseCSCMatrix, - CSparseCSFTensor) + CSparseCSFTensor, CExtensionIdRegistry) cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": cdef int import_pyarrow() except -1 @@ -40,3 +40,5 @@ cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) cdef object wrap_table(const shared_ptr[CTable]& ctable) cdef object wrap_batch(const shared_ptr[CRecordBatch]& cbatch) + cdef object pyarrow_wrap_extension_id_registry( + shared_ptr[CExtensionIdRegistry]& cregistry) diff --git a/python/pyarrow/_exec_plan.pxd b/python/pyarrow/_exec_plan.pxd new file mode 100644 index 00000000000..eafe19491b9 --- /dev/null +++ b/python/pyarrow/_exec_plan.pxd @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + +cdef is_supported_execplan_output_type(output_type) + +cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=*) diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx index 753abe27cfa..c47ddebb894 100644 --- a/python/pyarrow/_exec_plan.pyx +++ b/python/pyarrow/_exec_plan.pyx @@ -36,6 +36,9 @@ from pyarrow._dataset import InMemoryDataset Initialize() # Initialise support for Datasets in ExecPlan +cdef is_supported_execplan_output_type(output_type): + return output_type in [Table, InMemoryDataset] + cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True): """ Internal Function to create an ExecPlan and run it. 
@@ -75,6 +78,9 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads vector[CDeclaration.Input] no_c_inputs CStatus c_plan_status + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + if use_threads: c_executor = GetCpuThreadPool() else: @@ -214,6 +220,9 @@ def _perform_join(join_type, left_operand not None, left_keys, vector[c_string] c_projected_col_names CJoinType c_join_type + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + # Prepare left and right tables Keys to send them to the C++ function left_keys_order = {} if isinstance(left_keys, str): @@ -376,6 +385,9 @@ def _filter_table(table, expression, output_type=Table): vector[CDeclaration] c_decl_plan Expression expr = expression + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + c_decl_plan.push_back( CDeclaration(tobytes("filter"), CFilterNodeOptions( expr.unwrap(), True @@ -392,4 +404,4 @@ def _filter_table(table, expression, output_type=Table): # "__fragment_index", "__batch_index", "__last_in_fragment", "__filename" return InMemoryDataset(r.select(table.schema.names)) else: - raise TypeError("Unsupported output type") + raise TypeError(f"Unsupported output type {output_type}") diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 7f079fb717b..7374f66ac05 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -24,6 +24,59 @@ from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * +from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan + + +def make_extension_id_registry(): + cdef: + shared_ptr[CExtensionIdRegistry] c_registry + ExtensionIdRegistry registry + + with nogil: + c_registry = MakeExtensionIdRegistry() + + registry = 
ExtensionIdRegistry.__new__(ExtensionIdRegistry) + registry.registry = &deref(c_registry) + return registry + +def register_function(registry, id_uri, id_name, arrow_function_name): + cdef: + c_string c_id_uri, c_id_name, c_arrow_function_name + shared_ptr[CExtensionIdRegistry] c_registry + CStatus c_status + + c_registry = pyarrow_unwrap_extension_id_registry(registry) + c_id_uri = id_uri + c_id_name = id_name + c_arrow_function_name = arrow_function_name + + with nogil: + c_status = RegisterFunction( + deref(c_registry), c_id_uri, c_id_name, c_arrow_function_name + ) + + return c_status.ok() + +def run_query_as(plan, output_type=RecordBatchReader): + if output_type == RecordBatchReader: + return run_query(plan) + return _run_query(plan, output_type) + +def _run_query(plan, output_type): + cdef: + CResult[vector[CDeclaration]] c_res_decls + vector[CDeclaration] c_decls + shared_ptr[CBuffer] c_buf_plan + + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + + c_buf_plan = pyarrow_unwrap_buffer(plan) + with nogil: + c_res_decls = DeserializePlans(deref(c_buf_plan)) + c_decls = GetResultValue(c_res_decls) + return execplan([], output_type, c_decls) + def run_query(plan): """ Execute a Substrait plan and read the results as a RecordBatchReader. 
@@ -38,7 +91,6 @@ def run_query(plan): CResult[shared_ptr[CRecordBatchReader]] c_res_reader shared_ptr[CRecordBatchReader] c_reader RecordBatchReader reader - c_string c_str_plan shared_ptr[CBuffer] c_buf_plan c_buf_plan = pyarrow_unwrap_buffer(plan) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 2e51864b860..c28bc8c0416 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2696,3 +2696,8 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": CStatus RegisterScalarFunction(PyObject* function, function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + +cdef extern from "arrow/engine/substrait/extension_set.h" namespace "arrow::engine" nogil: + + cdef cppclass CExtensionIdRegistry" arrow::engine::ExtensionIdRegistry": + pass diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 2e1a17b06bd..6e91e31309d 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -21,6 +21,13 @@ from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * +cdef extern from "arrow/engine/substrait/extension_set.h" namespace "arrow::engine" nogil: + cdef cppclass CExtensionIdRegistry "arrow::engine::ExtensionIdRegistry" + cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::substrait" nogil: + shared_ptr[CExtensionIdRegistry] MakeExtensionIdRegistry() + CStatus RegisterFunction(CExtensionIdRegistry& registry, const c_string& id_uri, const c_string& id_name, const c_string& arrow_function_name) + CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer) CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) + CResult[vector[CDeclaration]] DeserializePlans(const CBuffer& substrait_buffer) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 
953b0e7b518..8b4624e9032 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -446,6 +446,14 @@ cdef class RecordBatch(_PandasConvertible): cdef void init(self, const shared_ptr[CRecordBatch]& table) +cdef class ExtensionIdRegistry(_Weakrefable): + cdef: + shared_ptr[CExtensionIdRegistry] sp_registry + CExtensionIdRegistry* registry + + cdef void init(self, shared_ptr[CExtensionIdRegistry]& registry) + + cdef class Buffer(_Weakrefable): cdef: shared_ptr[CBuffer] buffer @@ -585,6 +593,8 @@ cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor) cdef public object pyarrow_wrap_batch(const shared_ptr[CRecordBatch]& cbatch) cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable) +cdef public object pyarrow_wrap_extension_id_registry(shared_ptr[CExtensionIdRegistry]& cregistry) + # Unwrapping Python -> C++ cdef public shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer) @@ -611,3 +621,5 @@ cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor) cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch) cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table) + +cdef public shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(object registry) diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index c427fb9f5db..ee146c6dab8 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -416,3 +416,23 @@ cdef api object pyarrow_wrap_batch( cdef RecordBatch batch = RecordBatch.__new__(RecordBatch) batch.init(cbatch) return batch + + +cdef api bint pyarrow_is_extension_id_registry(object registry): + return isinstance(registry, ExtensionIdRegistry) + + +cdef api shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(object registry): + cdef ExtensionIdRegistry reg + if pyarrow_is_extension_id_registry(registry): + reg = (registry) + return reg.sp_registry + + return shared_ptr[CExtensionIdRegistry]() + + +cdef 
api object pyarrow_wrap_extension_id_registry( + shared_ptr[CExtensionIdRegistry]& cregistry): + cdef ExtensionIdRegistry registry = ExtensionIdRegistry.__new__(ExtensionIdRegistry) + registry.init(cregistry) + return registry diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index e3ff28f4eba..df0df0bf289 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -16,5 +16,8 @@ # under the License. from pyarrow._substrait import ( # noqa + make_extension_id_registry, + register_function, + run_query_as, run_query, ) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 27e0144d758..a72b55db970 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2568,6 +2568,20 @@ cdef class RecordBatch(_PandasConvertible): return pyarrow_wrap_batch(c_batch) +cdef class ExtensionIdRegistry(_Weakrefable): + + def __cinit__(self): + self.registry = NULL + + def __init__(self): + raise TypeError("Do not call ExtensionIdRegistry's constructor directly, use " + "the `MakeExtensionIdRegistry` function instead.") + + cdef void init(self, shared_ptr[CExtensionIdRegistry]& registry): + self.sp_registry = registry + self.registry = registry.get() + + def _reconstruct_record_batch(columns, schema): """ Internal: reconstruct RecordBatch from pickled components. 
From f49a85d95837397322eb3369323665a626be7840 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 24 May 2022 10:14:50 -0400 Subject: [PATCH 12/19] UDF PoC --- .../arrow/engine/substrait/extension_set.cc | 17 +- .../arrow/engine/substrait/extension_set.h | 9 +- .../arrow/engine/substrait/plan_internal.cc | 9 +- .../arrow/engine/substrait/plan_internal.h | 3 +- cpp/src/arrow/engine/substrait/serde.cc | 58 +++++- cpp/src/arrow/engine/substrait/serde.h | 41 +++- cpp/src/arrow/engine/substrait/util.cc | 25 ++- cpp/src/arrow/engine/substrait/util.h | 9 +- python/pyarrow/_substrait.pyx | 72 +++++-- .../pyarrow/includes/libarrow_substrait.pxd | 19 +- python/pyarrow/lib.pxd | 1 + python/pyarrow/substrait.py | 2 + python/pyarrow/tests/test_udf.py | 180 ++++++++++++++++++ 13 files changed, 396 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 3d4aa93ec9e..dc8579f0861 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -213,6 +213,8 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { namespace { struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + virtual ~ExtensionIdRegistryImpl() {} + std::vector Uris() const override { return {uris_.begin(), uris_.end()}; } @@ -231,8 +233,8 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - virtual Status CanRegisterType(Id id, std::shared_ptr type, - bool is_variation) const { + Status CanRegisterType(Id id, std::shared_ptr type, + bool is_variation) const override { auto& id_to_index = is_variation ? 
variation_id_to_index_ : id_to_index_; if (id_to_index.find(id) != id_to_index.end()) { return Status::Invalid("Type id was already registered"); @@ -285,11 +287,11 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - virtual Status CanRegisterFunction(Id id, std::string arrow_function_name) const { - if (function_id_to_index_.find(id) == function_id_to_index_.end()) { + Status CanRegisterFunction(Id id, std::string arrow_function_name) const override { + if (function_id_to_index_.find(id) != function_id_to_index_.end()) { return Status::Invalid("Function id was already registered"); } - if (function_name_to_index_.find(arrow_function_name) == + if (function_name_to_index_.find(arrow_function_name) != function_name_to_index_.end()) { return Status::Invalid("Function name was already registered"); } @@ -342,7 +344,10 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { }; struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { - NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} + explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} + + virtual ~NestedExtensionIdRegistryImpl() {} std::vector Uris() const override { std::vector uris = parent_->Uris(); diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index a6019333fc0..f37e8177a72 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -88,6 +89,12 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { util::string_view arrow_function_name) const = 0; virtual Status CanRegisterFunction(Id, std::string arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; + + const std::string& AddExternalSymbol(const std::string& 
symbol) { + return *external_symbols.insert(symbol).first; + } +private: + std::set external_symbols; }; constexpr util::string_view kArrowExtTypesUri = @@ -219,7 +226,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// \brief Lookup the anchor for a given function /// - /// This operation is used when converting an Arrow execution plan to a Substrait plan. + /// This operation is used when converting an Arrow execution plan to a Substrait plan. /// If the function has been previously encoded then the same anchor value will be /// returned. /// diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 8b3d53ae794..5c60e5655ab 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -111,7 +111,11 @@ void SetElement(size_t i, const Element& element, std::vector* vector) { } // namespace Result GetExtensionSetFromPlan(const substrait::Plan& plan, - const ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry, + bool exclude_functions) { + if (registry == NULLPTR) { + registry = default_extension_id_registry(); + } std::vector uris; for (const auto& uri : plan.extension_uris()) { SetElement(uri.extension_uri_anchor(), uri.uri(), &uris); @@ -143,6 +147,9 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, } case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + if (exclude_functions) { + break; + } const auto& fn = ext.extension_function(); util::string_view uri = uris[fn.extension_uri_reference()]; SetElement(fn.function_anchor(), Id{uri, fn.name()}, &function_ids); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index dce23cdceba..4f4f752f243 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, 
substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - const ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry(), + bool exclude_functions = false); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 57b4251c9fa..d2ccdd87711 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -61,10 +61,10 @@ Result DeserializeRelation(const Buffer& buf, static Result> DeserializePlans( const Buffer& buf, const std::string& factory_name, std::function()> options_factory, - ExtensionSet* ext_set_out) { + ExtensionSet* ext_set_out, const ExtensionIdRegistry* registry) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); - ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan)); + ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry)); std::vector sink_decls; for (const substrait::PlanRel& plan_rel : plan.relations()) { @@ -93,7 +93,7 @@ static Result> DeserializePlans( Result> DeserializePlans( const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out) { + ExtensionSet* ext_set_out, const ExtensionIdRegistry* registry) { return DeserializePlans( buf, "consuming_sink", @@ -102,21 +102,23 @@ Result> DeserializePlans( compute::ConsumingSinkNodeOptions{consumer_factory()} ); }, - ext_set_out + ext_set_out, + registry ); } Result> DeserializePlans( const Buffer& buf, const WriteOptionsFactory& write_options_factory, - ExtensionSet* ext_set_out) { - return DeserializePlans(buf, "write", write_options_factory, ext_set_out); + ExtensionSet* ext_set_out, const ExtensionIdRegistry* registry) { + return DeserializePlans(buf, "write", write_options_factory, ext_set_out, registry); } Result DeserializePlan(const Buffer& buf, const ConsumerFactory& 
consumer_factory, - ExtensionSet* ext_set_out) { + ExtensionSet* ext_set_out, + const ExtensionIdRegistry* registry) { ARROW_ASSIGN_OR_RAISE(auto declarations, - DeserializePlans(buf, consumer_factory, ext_set_out)); + DeserializePlans(buf, consumer_factory, ext_set_out, registry)); if (declarations.size() > 1) { return Status::Invalid("DeserializePlan does not support multiple root relations"); } else { @@ -126,6 +128,46 @@ Result DeserializePlan(const Buffer& buf, } } +Result> DeserializePlanUdfs( + const Buffer& buf, const ExtensionIdRegistry* registry) { + ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); + + ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry, true)); + + std::vector decls; + for (const auto& ext : plan.extensions()) { + switch (ext.mapping_type_case()) { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + const auto& fn = ext.extension_function(); + if (fn.has_udf()) { + const auto& udf = fn.udf(); + const auto& in_types = udf.input_types(); + int size = in_types.size(); + std::vector, bool>> input_types(size); + for (int i=0; i> DeserializeSchema(const Buffer& buf, const ExtensionSet& ext_set) { ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer(buf)); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 227c505c0e8..775fc4e0a4c 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -51,17 +51,52 @@ using ConsumerFactory = std::function /// Substrait Plan ARROW_ENGINE_EXPORT Result> DeserializePlans( const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out = NULLPTR); + ExtensionSet* ext_set_out = NULLPTR, const ExtensionIdRegistry* registry = NULLPTR); +/// \brief Deserializes a single-relation Substrait Plan message to an execution plan +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan +/// message +/// \param[in] 
consumer_factory factory function for generating the node that consumes +/// the batches produced by each toplevel Substrait relation +/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait +/// Plan is returned here. +/// \return an ExecNode corresponding to the single toplevel relation in the Substrait +/// Plan Result DeserializePlan(const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out = NULLPTR); + ExtensionSet* ext_set_out = NULLPTR, + const ExtensionIdRegistry* registry = NULLPTR); +/// Factory function type for generating the write options of a node consuming the batches +/// produced by each toplevel Substrait relation when deserializing a Substrait Plan. using WriteOptionsFactory = std::function()>; +/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan +/// message +/// \param[in] write_options_factory factory function for generating the write options of +/// a node consuming the batches produced by each toplevel Substrait relation +/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait +/// Plan is returned here. 
+/// \return a vector of ExecNode declarations, one for each toplevel relation in the +/// Substrait Plan ARROW_ENGINE_EXPORT Result> DeserializePlans( const Buffer& buf, const WriteOptionsFactory& write_options_factory, - ExtensionSet* ext_set = NULLPTR); + ExtensionSet* ext_set = NULLPTR, const ExtensionIdRegistry* registry = NULLPTR); + +struct ARROW_ENGINE_EXPORT UdfDeclaration { + std::string name; + std::string code; + std::string summary; + std::string description; + std::vector, bool>> input_types; + std::pair, bool> output_type; +}; + +ARROW_ENGINE_EXPORT Result> DeserializePlanUdfs( + const Buffer& buf, const ExtensionIdRegistry* registry); /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 74ec8ff33b2..53356d86a57 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -84,7 +84,7 @@ class SubstraitExecutor { Status Close() { return plan_->finished().status(); } - Status Init(const Buffer& substrait_buffer) { + Status Init(const Buffer& substrait_buffer, const ExtensionIdRegistry* registry) { if (substrait_buffer.size() == 0) { return Status::Invalid("Empty substrait plan is passed."); } @@ -93,7 +93,10 @@ class SubstraitExecutor { return sink_consumer_; }; ARROW_ASSIGN_OR_RAISE(declarations_, - engine::DeserializePlans(substrait_buffer, consumer_factory)); + engine::DeserializePlans(substrait_buffer, + consumer_factory, + NULLPTR, + registry)); return Status::OK(); } @@ -108,13 +111,13 @@ class SubstraitExecutor { } // namespace Result> ExecuteSerializedPlan( - const Buffer& substrait_buffer) { + const Buffer& substrait_buffer, const ExtensionIdRegistry* registry) { ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); // TODO(ARROW-15732) compute::ExecContext exec_context(arrow::default_memory_pool(), ::arrow::internal::GetCpuThreadPool()); SubstraitExecutor 
executor(std::move(plan), exec_context); - RETURN_NOT_OK(executor.Init(substrait_buffer)); + RETURN_NOT_OK(executor.Init(substrait_buffer, registry)); ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); return sink_reader; } @@ -123,9 +126,13 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } -Result> DeserializePlans(const Buffer& buffer) { +Result> DeserializePlans( + const Buffer& buffer, const ExtensionIdRegistry* registry) { return engine::DeserializePlans( - buffer, []() { return std::make_shared(); } + buffer, + []() { return std::make_shared(); }, + NULLPTR, + registry ); } @@ -136,7 +143,11 @@ std::shared_ptr MakeExtensionIdRegistry() { Status RegisterFunction(ExtensionIdRegistry& registry, const std::string& id_uri, const std::string& id_name, const std::string& arrow_function_name) { - return registry.RegisterFunction({id_uri, id_name}, arrow_function_name); + const std::string& id_uri_sym = registry.AddExternalSymbol(id_uri); + const std::string& id_name_sym = registry.AddExternalSymbol(id_name); + const std::string& arrow_function_name_sym = + registry.AddExternalSymbol(arrow_function_name); + return registry.RegisterFunction({id_uri_sym, id_name_sym}, arrow_function_name_sym); } } // namespace substrait diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index c5e150a6cb2..27fc3a8aaed 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -30,7 +30,7 @@ namespace substrait { /// \brief Retrieve a RecordBatchReader from a Substrait plan. ARROW_ENGINE_EXPORT Result> ExecuteSerializedPlan( - const Buffer& substrait_buffer); + const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR); /// \brief Get a Serialized Plan from a Substrait JSON plan. /// This is a helper method for Python tests. 
@@ -38,7 +38,7 @@ ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); ARROW_ENGINE_EXPORT Result> DeserializePlans( - const Buffer& buf); + const Buffer& buf, const ExtensionIdRegistry* registry); ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); @@ -47,6 +47,11 @@ ARROW_ENGINE_EXPORT Status RegisterFunction(ExtensionIdRegistry& registry, const std::string& id_name, const std::string& arrow_function_name); +ARROW_ENGINE_EXPORT inline const std::string& default_extension_types_uri() { + static std::string uri = engine::kArrowExtTypesUri.to_string(); + return uri; +} + } // namespace substrait } // namespace engine diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 7374f66ac05..4128b37df85 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -16,9 +16,10 @@ # under the License. # cython: language_level = 3 -from cython.operator cimport dereference as deref +from cython.operator cimport dereference as deref, preincrement as inc from pyarrow import Buffer +from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * @@ -35,9 +36,42 @@ def make_extension_id_registry(): with nogil: c_registry = MakeExtensionIdRegistry() - registry = ExtensionIdRegistry.__new__(ExtensionIdRegistry) - registry.registry = &deref(c_registry) - return registry + return pyarrow_wrap_extension_id_registry(c_registry) + +def get_udf_declarations(plan, registry): + cdef: + shared_ptr[CBuffer] c_buf_plan + shared_ptr[CExtensionIdRegistry] c_registry + vector[CUdfDeclaration] c_decls + vector[CUdfDeclaration].iterator c_decls_iter + vector[pair[shared_ptr[CDataType], c_bool]].iterator c_in_types_iter + + c_buf_plan = pyarrow_unwrap_buffer(plan) + c_registry = pyarrow_unwrap_extension_id_registry(registry) + with nogil: + c_res_decls = DeserializePlanUdfs(deref(c_buf_plan), &deref(c_registry)) + c_decls =
GetResultValue(c_res_decls) + + decls = [] + c_decls_iter = c_decls.begin() + while c_decls_iter != c_decls.end(): + input_types = [] + c_in_types_iter = deref(c_decls_iter).input_types.begin() + while c_in_types_iter != deref(c_decls_iter).input_types.end(): + input_types.append((pyarrow_wrap_data_type(deref(c_in_types_iter).first), + deref(c_in_types_iter).second)) + inc(c_in_types_iter) + decls.append({ + "name": frombytes(deref(c_decls_iter).name), + "code": frombytes(deref(c_decls_iter).code), + "summary": frombytes(deref(c_decls_iter).summary), + "description": frombytes(deref(c_decls_iter).description), + "input_types": input_types, + "output_type": (pyarrow_wrap_data_type(deref(c_decls_iter).output_type.first), + deref(c_decls_iter).output_type.second), + }) + inc(c_decls_iter) + return decls def register_function(registry, id_uri, id_name, arrow_function_name): cdef: @@ -46,38 +80,40 @@ def register_function(registry, id_uri, id_name, arrow_function_name): CStatus c_status c_registry = pyarrow_unwrap_extension_id_registry(registry) - c_id_uri = id_uri - c_id_name = id_name - c_arrow_function_name = arrow_function_name + c_id_uri = id_uri or default_extension_types_uri() + c_id_name = tobytes(id_name) + c_arrow_function_name = tobytes(arrow_function_name) with nogil: c_status = RegisterFunction( deref(c_registry), c_id_uri, c_id_name, c_arrow_function_name ) - return c_status.ok() + check_status(c_status) -def run_query_as(plan, output_type=RecordBatchReader): +def run_query_as(plan, registry, output_type=RecordBatchReader): if output_type == RecordBatchReader: - return run_query(plan) - return _run_query(plan, output_type) + return run_query(plan, registry) + return _run_query(plan, registry, output_type) -def _run_query(plan, output_type): +def _run_query(plan, registry, output_type): cdef: + shared_ptr[CBuffer] c_buf_plan + shared_ptr[CExtensionIdRegistry] c_registry CResult[vector[CDeclaration]] c_res_decls vector[CDeclaration] c_decls - 
shared_ptr[CBuffer] c_buf_plan if not is_supported_execplan_output_type(output_type): raise TypeError(f"Unsupported output type {output_type}") c_buf_plan = pyarrow_unwrap_buffer(plan) + c_registry = pyarrow_unwrap_extension_id_registry(registry) with nogil: - c_res_decls = DeserializePlans(deref(c_buf_plan)) + c_res_decls = DeserializePlans(deref(c_buf_plan), &deref(c_registry)) c_decls = GetResultValue(c_res_decls) return execplan([], output_type, c_decls) -def run_query(plan): +def run_query(plan, registry): """ Execute a Substrait plan and read the results as a RecordBatchReader. @@ -88,14 +124,16 @@ def run_query(plan): """ cdef: + shared_ptr[CBuffer] c_buf_plan + shared_ptr[CExtensionIdRegistry] c_registry CResult[shared_ptr[CRecordBatchReader]] c_res_reader shared_ptr[CRecordBatchReader] c_reader RecordBatchReader reader - shared_ptr[CBuffer] c_buf_plan c_buf_plan = pyarrow_unwrap_buffer(plan) + c_registry = pyarrow_unwrap_extension_id_registry(registry) with nogil: - c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan)) + c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan), &deref(c_registry)) c_reader = GetResultValue(c_res_reader) diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 6e91e31309d..b21497b8340 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -22,12 +22,25 @@ from pyarrow.includes.libarrow cimport * cdef extern from "arrow/engine/substrait/extension_set.h" namespace "arrow::engine" nogil: - cdef cppclass CExtensionIdRegistry "arrow::engine::ExtensionIdRegistry" + cppclass CExtensionIdRegistry "arrow::engine::ExtensionIdRegistry" + +cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogil: + cppclass CUdfDeclaration "arrow::engine::UdfDeclaration": + c_string name + c_string code + c_string summary + c_string description + vector[pair[shared_ptr[CDataType], c_bool]] input_types; + 
pair[shared_ptr[CDataType], c_bool] output_type; + + CResult[vector[CUdfDeclaration]] DeserializePlanUdfs(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::substrait" nogil: shared_ptr[CExtensionIdRegistry] MakeExtensionIdRegistry() CStatus RegisterFunction(CExtensionIdRegistry& registry, const c_string& id_uri, const c_string& id_name, const c_string& arrow_function_name) - CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer) + CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) - CResult[vector[CDeclaration]] DeserializePlans(const CBuffer& substrait_buffer) + CResult[vector[CDeclaration]] DeserializePlans(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) + + const c_string& default_extension_types_uri() diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 8b4624e9032..35678304729 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -24,6 +24,7 @@ from libcpp.memory cimport dynamic_pointer_cast from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_python cimport * +from pyarrow.includes.libarrow_substrait cimport * cdef extern from "Python.h": diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index df0df0bf289..eb7d9795491 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -17,7 +17,9 @@ from pyarrow._substrait import ( # noqa make_extension_id_registry, + get_udf_declarations, register_function, run_query_as, run_query, + _parse_json_plan, ) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 33315fc12d4..bdcda9b2ff5 100644 --- a/python/pyarrow/tests/test_udf.py +++ 
b/python/pyarrow/tests/test_udf.py @@ -16,10 +16,14 @@ # under the License. +import base64 +import cloudpickle +import os import pytest import pyarrow as pa from pyarrow import compute as pc +from pyarrow.lib import frombytes, tobytes, DoubleArray # UDFs are all tested with a dataset scan pytestmark = pytest.mark.dataset @@ -30,6 +34,11 @@ except ImportError: ds = None +try: + import pyarrow.substrait as substrait +except ImportError: + substrait = None + def mock_udf_context(batch_length=10): from pyarrow._compute import _get_scalar_udf_context @@ -501,3 +510,174 @@ def test_input_lifetime(unary_func_fixture): # Calling a UDF should not have kept `v` alive longer than required v = None assert proxy_pool.bytes_allocated() == 0 + + +def demean_and_zscore(scl_udf_ctx, v): + mean = v.mean() + std = v.std() + return v - mean, (v - mean) / std + +def twice_and_add_2(scl_udf_ctx, v): + return 2 * v, v + 2 + +def twice(scl_udf_ctx, v): + return DoubleArray.from_pandas((2 * v.to_pandas())) + + +def test_elementwise_scalar_udf_in_substrait_query(tmpdir): + substrait_query = """ + { + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "twice", + "udf": { + "code": "CODE_PLACEHOLDER", + "summary": "twice", + "description": "Compute twice the value of the input", + "inputTypes": [ + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "key", + "value" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability":
"NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER" + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "scalarFunction": { + "functionReference": 1, + "args": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ], + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ] + } + }, + "names": [ + "key", + "value", + "twice" + ] + } + } + ] + } + """ + # TODO: replace with ipc when the support is finalized in C++ + code = frombytes(base64.b64encode(cloudpickle.dumps(twice))) + path = os.path.join(str(tmpdir), 'substrait_data.arrow') + table = pa.table([["a", "b", "a", "b", "a"], [1.0, 2.0, 3.0, 4.0, 5.0]], names=['key', 'value']) + with pa.ipc.RecordBatchFileWriter(path, schema=table.schema) as writer: + writer.write_table(table) + + query = tobytes(substrait_query.replace("CODE_PLACEHOLDER", code).replace("FILENAME_PLACEHOLDER", path)) + + plan = substrait._parse_json_plan(query) + + registry = substrait.make_extension_id_registry() + udf_decls = substrait.get_udf_declarations(plan, registry) + for udf_decl in udf_decls: + substrait.register_function(registry, None, udf_decl["name"], udf_decl["name"]) + pc.register_scalar_function( + cloudpickle.loads(base64.b64decode(tobytes(udf_decl["code"]))), + udf_decl["name"], + {"summary": udf_decl["summary"], "description": udf_decl["description"]}, + {f"arg{i}": type_nullable_pair[0] + for i, type_nullable_pair in enumerate(udf_decl["input_types"]) + }, + udf_decl["output_type"][0], + ) + + reader = substrait.run_query(plan, registry) + res_tb = reader.read_all() + + assert len(res_tb) == len(table) + assert res_tb.schema == pa.schema([("key", pa.string()),
("value", pa.float64()), ("twice", pa.float64())]) + assert res_tb.drop(["twice"]) == table From 5795a86eb2addd23ca00604c473692124853db6c Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 30 May 2022 06:46:03 -0400 Subject: [PATCH 13/19] UDF PoC with scoped registries --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/registry.cc | 192 ++++++++++++- cpp/src/arrow/compute/registry.h | 31 ++- cpp/src/arrow/compute/registry_test.cc | 166 ++++++++++- cpp/src/arrow/compute/registry_util.cc | 30 ++ cpp/src/arrow/compute/registry_util.h | 33 +++ cpp/src/arrow/engine/substrait/ext_test.cc | 263 ++++++++++++++++++ .../arrow/engine/substrait/extension_set.h | 11 + cpp/src/arrow/engine/substrait/util.cc | 8 +- cpp/src/arrow/engine/substrait/util.h | 4 +- cpp/src/arrow/python/udf.cc | 7 +- cpp/src/arrow/python/udf.h | 6 +- python/pyarrow/_compute.pxd | 3 + python/pyarrow/_compute.pyx | 32 ++- python/pyarrow/_exec_plan.pxd | 2 +- python/pyarrow/_exec_plan.pyx | 4 +- python/pyarrow/_substrait.pyx | 96 +++++-- python/pyarrow/compute.pxi | 33 +++ python/pyarrow/includes/libarrow.pxd | 10 +- .../pyarrow/includes/libarrow_substrait.pxd | 6 +- python/pyarrow/lib.pxd | 2 + python/pyarrow/lib.pyx | 3 + python/pyarrow/public-api.pxi | 19 ++ python/pyarrow/substrait.py | 3 + python/pyarrow/table.pxi | 14 - python/pyarrow/tests/test_substrait.py | 9 +- python/pyarrow/tests/test_udf.py | 30 +- 27 files changed, 907 insertions(+), 111 deletions(-) create mode 100644 cpp/src/arrow/compute/registry_util.cc create mode 100644 cpp/src/arrow/compute/registry_util.h create mode 100644 cpp/src/arrow/engine/substrait/ext_test.cc create mode 100644 python/pyarrow/compute.pxi diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index ec6cada1cda..ce14d8af402 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -408,6 +408,7 @@ if(ARROW_COMPUTE) compute/kernel.cc compute/light_array.cc compute/registry.cc + compute/registry_util.cc 
compute/kernels/aggregate_basic.cc compute/kernels/aggregate_mode.cc compute/kernels/aggregate_quantile.cc diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 8ab83a72e5e..5f0a43468c1 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -33,8 +33,18 @@ namespace arrow { namespace compute { class FunctionRegistry::FunctionRegistryImpl { - public: - Status AddFunction(std::shared_ptr function, bool allow_overwrite) { + private: + using FuncAdd = std::function)>; + + const FuncAdd kFuncAddNoOp = [](const std::string& name, + std::shared_ptr func) {}; + const FuncAdd kFuncAddDo = [this](const std::string& name, + std::shared_ptr func) { + name_to_function_[name] = func; + }; + + Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, + FuncAdd add) { #ifndef NDEBUG // This validates docstrings extensively, so don't waste time on it // in release builds. @@ -48,23 +58,56 @@ class FunctionRegistry::FunctionRegistryImpl { if (it != name_to_function_.end() && !allow_overwrite) { return Status::KeyError("Already have a function registered with name: ", name); } - name_to_function_[name] = std::move(function); + add(name, std::move(function)); return Status::OK(); } - Status AddAlias(const std::string& target_name, const std::string& source_name) { + public: + virtual Status CanAddFunction(std::shared_ptr function, + bool allow_overwrite) { + return DoAddFunction(function, allow_overwrite, kFuncAddNoOp); + } + + virtual Status AddFunction(std::shared_ptr function, bool allow_overwrite) { + return DoAddFunction(function, allow_overwrite, kFuncAddDo); + } + + private: + Status DoAddAlias(const std::string& target_name, const std::string& source_name, + FuncAdd add) { std::lock_guard mutation_guard(lock_); - auto it = name_to_function_.find(source_name); - if (it == name_to_function_.end()) { + auto func_res = GetFunction(source_name); // must not acquire the mutex + if (!func_res.ok()) { 
return Status::KeyError("No function registered with name: ", source_name); } - name_to_function_[target_name] = it->second; + add(target_name, func_res.ValueOrDie()); return Status::OK(); } - Status AddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) { + public: + virtual Status CanAddAlias(const std::string& target_name, + const std::string& source_name) { + return DoAddAlias(target_name, source_name, kFuncAddNoOp); + } + + virtual Status AddAlias(const std::string& target_name, + const std::string& source_name) { + return DoAddAlias(target_name, source_name, kFuncAddDo); + } + + private: + using FuncOptTypeAdd = std::function; + + const FuncOptTypeAdd kFuncOptTypeAddNoOp = [](const FunctionOptionsType* options_type) { + }; + const FuncOptTypeAdd kFuncOptTypeAddDo = + [this](const FunctionOptionsType* options_type) { + name_to_options_type_[options_type->type_name()] = options_type; + }; + + Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite, FuncOptTypeAdd add) { std::lock_guard mutation_guard(lock_); const std::string name = options_type->type_name(); @@ -73,11 +116,22 @@ class FunctionRegistry::FunctionRegistryImpl { return Status::KeyError( "Already have a function options type registered with name: ", name); } - name_to_options_type_[name] = options_type; + add(options_type); return Status::OK(); } - Result> GetFunction(const std::string& name) const { + public: + virtual Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddNoOp); + } + + virtual Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddDo); + } + + virtual Result> GetFunction(const std::string& name) const { auto it = 
name_to_function_.find(name); if (it == name_to_function_.end()) { return Status::KeyError("No function registered with name: ", name); @@ -85,7 +139,7 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - std::vector GetFunctionNames() const { + virtual std::vector GetFunctionNames() const { std::vector results; for (auto it : name_to_function_) { results.push_back(it.first); @@ -94,7 +148,7 @@ class FunctionRegistry::FunctionRegistryImpl { return results; } - Result GetFunctionOptionsType( + virtual Result GetFunctionOptionsType( const std::string& name) const { auto it = name_to_options_type_.find(name); if (it == name_to_options_type_.end()) { @@ -103,7 +157,7 @@ class FunctionRegistry::FunctionRegistryImpl { return it->second; } - int num_functions() const { return static_cast(name_to_function_.size()); } + virtual int num_functions() const { return static_cast(name_to_function_.size()); } private: std::mutex lock_; @@ -111,24 +165,132 @@ class FunctionRegistry::FunctionRegistryImpl { std::unordered_map name_to_options_type_; }; +class FunctionRegistry::NestedFunctionRegistryImpl + : public FunctionRegistry::FunctionRegistryImpl { + public: + explicit NestedFunctionRegistryImpl(FunctionRegistry::FunctionRegistryImpl* parent) + : parent_(parent) {} + + Status CanAddFunction(std::shared_ptr function, + bool allow_overwrite) override { + return parent_->CanAddFunction(function, allow_overwrite) & + FunctionRegistry::FunctionRegistryImpl::CanAddFunction(function, + allow_overwrite); + } + + Status AddFunction(std::shared_ptr function, bool allow_overwrite) override { + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); + return FunctionRegistry::FunctionRegistryImpl::AddFunction(function, allow_overwrite); + } + + Status CanAddAlias(const std::string& target_name, + const std::string& source_name) override { + Status st = + FunctionRegistry::FunctionRegistryImpl::CanAddAlias(target_name, source_name); + return st.ok() ?
st : parent_->CanAddAlias(target_name, source_name); + } + + Status AddAlias(const std::string& target_name, + const std::string& source_name) override { + Status st = + FunctionRegistry::FunctionRegistryImpl::AddAlias(target_name, source_name); + return st.ok() ? st : parent_->AddAlias(target_name, source_name); + } + + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) override { + return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & + FunctionRegistry::FunctionRegistryImpl::CanAddFunctionOptionsType( + options_type, allow_overwrite); + } + + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) override { + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + return FunctionRegistry::FunctionRegistryImpl::AddFunctionOptionsType( + options_type, allow_overwrite); + } + + Result> GetFunction(const std::string& name) const override { + auto func_res = FunctionRegistry::FunctionRegistryImpl::GetFunction(name); + if (func_res.ok()) { + return func_res; + } + return parent_->GetFunction(name); + } + + std::vector GetFunctionNames() const override { + auto names = parent_->GetFunctionNames(); + auto more_names = FunctionRegistry::FunctionRegistryImpl::GetFunctionNames(); + names.insert(names.end(), std::make_move_iterator(more_names.begin()), + std::make_move_iterator(more_names.end())); + return names; + } + + Result GetFunctionOptionsType( + const std::string& name) const override { + auto options_type_res = + FunctionRegistry::FunctionRegistryImpl::GetFunctionOptionsType(name); + if (options_type_res.ok()) { + return options_type_res; + } + return parent_->GetFunctionOptionsType(name); + } + + int num_functions() const override { + return parent_->num_functions() + + FunctionRegistry::FunctionRegistryImpl::num_functions(); + } + + private: + FunctionRegistry::FunctionRegistryImpl* parent_; +}; + std::unique_ptr
FunctionRegistry::Make() { return std::unique_ptr(new FunctionRegistry()); } -FunctionRegistry::FunctionRegistry() { impl_.reset(new FunctionRegistryImpl()); } +std::unique_ptr FunctionRegistry::Make(FunctionRegistry* parent) { + return std::unique_ptr(new FunctionRegistry( + new FunctionRegistry::NestedFunctionRegistryImpl(&*parent->impl_))); +} + +std::unique_ptr FunctionRegistry::Make( + std::unique_ptr parent) { + return FunctionRegistry::Make(&*parent); +} + +FunctionRegistry::FunctionRegistry() : FunctionRegistry(new FunctionRegistryImpl()) {} + +FunctionRegistry::FunctionRegistry(FunctionRegistryImpl* impl) { impl_.reset(impl); } FunctionRegistry::~FunctionRegistry() {} +Status FunctionRegistry::CanAddFunction(std::shared_ptr function, + bool allow_overwrite) { + return impl_->CanAddFunction(std::move(function), allow_overwrite); +} + Status FunctionRegistry::AddFunction(std::shared_ptr function, bool allow_overwrite) { return impl_->AddFunction(std::move(function), allow_overwrite); } +Status FunctionRegistry::CanAddAlias(const std::string& target_name, + const std::string& source_name) { + return impl_->CanAddAlias(target_name, source_name); +} + Status FunctionRegistry::AddAlias(const std::string& target_name, const std::string& source_name) { return impl_->AddAlias(target_name, source_name); } +Status FunctionRegistry::CanAddFunctionOptionsType( + const FunctionOptionsType* options_type, bool allow_overwrite) { + return impl_->CanAddFunctionOptionsType(options_type, allow_overwrite); +} + Status FunctionRegistry::AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite) { return impl_->AddFunctionOptionsType(options_type, allow_overwrite); diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h index e83036db6ac..de074e10d92 100644 --- a/cpp/src/arrow/compute/registry.h +++ b/cpp/src/arrow/compute/registry.h @@ -45,20 +45,42 @@ class FunctionOptionsType; /// lower-level function execution. 
class ARROW_EXPORT FunctionRegistry { public: - ~FunctionRegistry(); + virtual ~FunctionRegistry(); /// \brief Construct a new registry. Most users only need to use the global /// registry static std::unique_ptr Make(); + /// \brief Construct a new nested registry with the given parent. Most users only need + /// to use the global registry + static std::unique_ptr Make(FunctionRegistry* parent); + + /// \brief Construct a new nested registry with the given parent. Most users only need + /// to use the global registry + static std::unique_ptr Make(std::unique_ptr parent); + + /// \brief Checks whether a new function can be added to the registry. Returns + /// Status::KeyError if a function with the same name is already registered + Status CanAddFunction(std::shared_ptr function, bool allow_overwrite = false); + /// \brief Add a new function to the registry. Returns Status::KeyError if a /// function with the same name is already registered Status AddFunction(std::shared_ptr function, bool allow_overwrite = false); - /// \brief Add aliases for the given function name. Returns Status::KeyError if the + /// \brief Checks whether an alias can be added for the given function name. Returns + /// Status::KeyError if the function with the given name is not registered + Status CanAddAlias(const std::string& target_name, const std::string& source_name); + + /// \brief Add alias for the given function name. Returns Status::KeyError if the /// function with the given name is not registered Status AddAlias(const std::string& target_name, const std::string& source_name); + /// \brief Checks whether a new function options type can be added to the registry. + /// Returns Status::KeyError if a function options type with the same name is already + /// registered + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false); + /// \brief Add a new function options type to the registry. 
Returns Status::KeyError if /// a function options type with the same name is already registered Status AddFunctionOptionsType(const FunctionOptionsType* options_type, @@ -84,6 +106,11 @@ class ARROW_EXPORT FunctionRegistry { // Use PIMPL pattern to not have std::unordered_map here class FunctionRegistryImpl; std::unique_ptr impl_; + + explicit FunctionRegistry(FunctionRegistryImpl* impl); + + class NestedFunctionRegistryImpl; + friend class NestedFunctionRegistryImpl; }; /// \brief Return the process-global function registry diff --git a/cpp/src/arrow/compute/registry_test.cc b/cpp/src/arrow/compute/registry_test.cc index faf47a46f68..319b6be7c08 100644 --- a/cpp/src/arrow/compute/registry_test.cc +++ b/cpp/src/arrow/compute/registry_test.cc @@ -27,37 +27,44 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/macros.h" +#include "arrow/util/make_unique.h" namespace arrow { namespace compute { -class TestRegistry : public ::testing::Test { - public: - void SetUp() { registry_ = FunctionRegistry::Make(); } +using MakeFunctionRegistry = std::function()>; +using GetNumFunctions = std::function; +using GetFunctionNames = std::function()>; +using TestRegistryParams = + std::tuple; - protected: - std::unique_ptr registry_; -}; +struct TestRegistry : public ::testing::TestWithParam {}; -TEST_F(TestRegistry, CreateBuiltInRegistry) { +TEST(TestRegistry, CreateBuiltInRegistry) { // This does DCHECK_OK internally for now so this will fail in debug builds // if there is a problem initializing the global function registry FunctionRegistry* registry = GetFunctionRegistry(); ARROW_UNUSED(registry); } -TEST_F(TestRegistry, Basics) { - ASSERT_EQ(0, registry_->num_functions()); +TEST_P(TestRegistry, Basics) { + auto registry_factory = std::get<0>(GetParam()); + auto registry_ = registry_factory(); + auto get_num_funcs = std::get<1>(GetParam()); + int n_funcs = get_num_funcs(); + auto get_func_names = std::get<2>(GetParam()); + std::vector 
func_names = get_func_names(); + ASSERT_EQ(n_funcs + 0, registry_->num_functions()); std::shared_ptr func = std::make_shared( "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); ASSERT_OK(registry_->AddFunction(func)); - ASSERT_EQ(1, registry_->num_functions()); + ASSERT_EQ(n_funcs + 1, registry_->num_functions()); func = std::make_shared("f0", Arity::Binary(), /*doc=*/FunctionDoc::Empty()); ASSERT_OK(registry_->AddFunction(func)); - ASSERT_EQ(2, registry_->num_functions()); + ASSERT_EQ(n_funcs + 2, registry_->num_functions()); ASSERT_OK_AND_ASSIGN(std::shared_ptr f1, registry_->GetFunction("f1")); ASSERT_EQ("f1", f1->name()); @@ -75,7 +82,10 @@ TEST_F(TestRegistry, Basics) { ASSERT_OK_AND_ASSIGN(f1, registry_->GetFunction("f1")); ASSERT_EQ(Function::SCALAR_AGGREGATE, f1->kind()); - std::vector expected_names = {"f0", "f1"}; + std::vector expected_names(func_names); + for (auto name : {"f0", "f1"}) { + expected_names.push_back(name); + } ASSERT_EQ(expected_names, registry_->GetFunctionNames()); // Aliases @@ -85,5 +95,137 @@ TEST_F(TestRegistry, Basics) { ASSERT_EQ(func, f2); } +INSTANTIATE_TEST_SUITE_P( + TestRegistry, TestRegistry, + testing::Values( + std::make_tuple( + static_cast([]() { return FunctionRegistry::Make(); }), + []() { return 0; }, []() { return std::vector{}; }, "default"), + std::make_tuple( + static_cast([]() { + return FunctionRegistry::Make(GetFunctionRegistry()); + }), + []() { return GetFunctionRegistry()->num_functions(); }, + []() { return GetFunctionRegistry()->GetFunctionNames(); }, "nested"))); + +TEST(TestRegistry, RegisterTempFunctions) { + auto default_registry = GetFunctionRegistry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = FunctionRegistry::Make(default_registry); + for (std::string func_name : {"f1", "f2"}) { + std::shared_ptr func = std::make_shared( + func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + ASSERT_OK(registry->CanAddFunction(func)); + 
ASSERT_OK(registry->AddFunction(func)); + ASSERT_RAISES(KeyError, registry->CanAddFunction(func)); + ASSERT_RAISES(KeyError, registry->AddFunction(func)); + ASSERT_OK(default_registry->CanAddFunction(func)); + } + } +} + +TEST(TestRegistry, RegisterTempAliases) { + auto default_registry = GetFunctionRegistry(); + std::vector func_names = default_registry->GetFunctionNames(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = FunctionRegistry::Make(default_registry); + for (std::string func_name : func_names) { + std::string alias_name = "alias_of_" + func_name; + std::shared_ptr func = std::make_shared( + func_name, Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + ASSERT_RAISES(KeyError, registry->GetFunction(alias_name)); + ASSERT_OK(registry->CanAddAlias(alias_name, func_name)); + ASSERT_OK(registry->AddAlias(alias_name, func_name)); + ASSERT_OK(registry->GetFunction(alias_name)); + ASSERT_OK(default_registry->GetFunction(func_name)); + ASSERT_RAISES(KeyError, default_registry->GetFunction(alias_name)); + } + } +} + +template +class ExampleOptions : public FunctionOptions { + public: + explicit ExampleOptions(std::shared_ptr value); + std::shared_ptr value; +}; + +template +class ExampleOptionsType : public FunctionOptionsType { + public: + static const FunctionOptionsType* GetInstance() { + static std::unique_ptr instance(new ExampleOptionsType()); + return instance.get(); + } + const char* type_name() const override { + static std::string name = std::string("example") + std::to_string(N); + return name.c_str(); + } + std::string Stringify(const FunctionOptions& options) const override { + return type_name(); + } + bool Compare(const FunctionOptions& options, + const FunctionOptions& other) const override { + return true; + } + std::unique_ptr Copy(const FunctionOptions& options) const override { + const auto& opts = static_cast&>(options); + return arrow::internal::make_unique>(opts.value); + } +}; +template 
+ExampleOptions::ExampleOptions(std::shared_ptr value) + : FunctionOptions(ExampleOptionsType::GetInstance()), value(std::move(value)) {} + +TEST(TestRegistry, RegisterTempFunctionOptionsType) { + auto default_registry = GetFunctionRegistry(); + std::vector options_types = { + ExampleOptionsType<1>::GetInstance(), + ExampleOptionsType<2>::GetInstance(), + }; + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = FunctionRegistry::Make(default_registry); + for (auto options_type : options_types) { + ASSERT_OK(registry->CanAddFunctionOptionsType(options_type)); + ASSERT_OK(registry->AddFunctionOptionsType(options_type)); + ASSERT_RAISES(KeyError, registry->CanAddFunctionOptionsType(options_type)); + ASSERT_RAISES(KeyError, registry->AddFunctionOptionsType(options_type)); + ASSERT_OK(default_registry->CanAddFunctionOptionsType(options_type)); + } + } +} + +TEST(TestRegistry, RegisterNestedFunctions) { + auto default_registry = GetFunctionRegistry(); + std::shared_ptr func1 = std::make_shared( + "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + std::shared_ptr func2 = std::make_shared( + "f2", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = FunctionRegistry::Make(default_registry); + + ASSERT_OK(registry1->CanAddFunction(func1)); + ASSERT_OK(registry1->AddFunction(func1)); + + for (int j = 0; j < rounds; j++) { + auto registry2 = FunctionRegistry::Make(registry1.get()); + + ASSERT_OK(registry2->CanAddFunction(func2)); + ASSERT_OK(registry2->AddFunction(func2)); + ASSERT_RAISES(KeyError, registry2->CanAddFunction(func2)); + ASSERT_RAISES(KeyError, registry2->AddFunction(func2)); + ASSERT_OK(default_registry->CanAddFunction(func2)); + } + + ASSERT_RAISES(KeyError, registry1->CanAddFunction(func1)); + ASSERT_RAISES(KeyError, registry1->AddFunction(func1)); + ASSERT_OK(default_registry->CanAddFunction(func1)); + } +} + } // namespace compute } // namespace 
arrow diff --git a/cpp/src/arrow/compute/registry_util.cc b/cpp/src/arrow/compute/registry_util.cc new file mode 100644 index 00000000000..f116b68c8a6 --- /dev/null +++ b/cpp/src/arrow/compute/registry_util.cc @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/registry_util.h" + +#include "arrow/compute/registry.h" + +namespace arrow { +namespace compute { + +std::unique_ptr MakeFunctionRegistry() { + return FunctionRegistry::Make(GetFunctionRegistry()); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry_util.h b/cpp/src/arrow/compute/registry_util.h new file mode 100644 index 00000000000..14e9bc5381c --- /dev/null +++ b/cpp/src/arrow/compute/registry_util.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include "arrow/compute/registry.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \brief Make a nested function registry with the default one as parent +ARROW_EXPORT std::unique_ptr MakeFunctionRegistry(); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc new file mode 100644 index 00000000000..482212d75a6 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/util.h" + +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" + +using testing::ElementsAre; +using testing::Eq; +using testing::HasSubstr; +using testing::UnorderedElementsAre; + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +// an extension-id-registry provider to be used as a test parameter +// +// we cannot pass a pointer to a nested registry as a test parameter because the +// shared_ptr in which it is made would not be held and get destructed too early, +// nor can we pass a shared_ptr to the default nested registry as a test parameter +// because it is global and must never be cleaned up, so we pass a shared_ptr to a +// provider that either owns or does not own the registry it provides, depending +// on the case. +struct ExtensionIdRegistryProvider { + virtual ExtensionIdRegistry* get() const = 0; +}; + +struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + virtual ~DefaultExtensionIdRegistryProvider() {} + ExtensionIdRegistry* get() const override { return default_extension_id_registry(); } +}; + +struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + virtual ~NestedExtensionIdRegistryProvider() {} + std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry(); + ExtensionIdRegistry* get() const override { return &*registry_; } +}; + +using Id = ExtensionIdRegistry::Id; + +bool operator==(const Id& id1, const Id& id2) { + return id1.uri == id2.uri && id1.name == id2.name; +} + +bool operator!=(const Id& id1, const Id& id2) { return !(id1 == id2); } + +struct TypeName { + std::shared_ptr type; + util::string_view name; +}; + +static const std::vector kTypeNames = { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + 
TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, +}; + +static const std::vector kFunctionNames = { + "add", +}; + +static const std::vector kTempFunctionNames = { + "temp_func_1", + "temp_func_2", +}; + +static const std::vector kTempTypeNames = { + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"}, + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"}, +}; + +using ExtensionIdRegistryParams = + std::tuple, std::string>; + +struct ExtensionIdRegistryTest + : public testing::TestWithParam {}; + +TEST_P(ExtensionIdRegistryTest, GetTypes) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (TypeName e : kTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + for (auto typerec_opt : {registry->GetType(id), registry->GetType(*e.type)}) { + ASSERT_TRUE(typerec_opt); + auto typerec = typerec_opt.value(); + ASSERT_EQ(id, typerec.id); + ASSERT_EQ(*e.type, *typerec.type); + } + } +} + +TEST_P(ExtensionIdRegistryTest, ReregisterTypes) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (TypeName e : kTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type)); + } +} + +TEST_P(ExtensionIdRegistryTest, GetFunctions) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (util::string_view name : kFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) { + ASSERT_TRUE(funcrec_opt); + auto funcrec = funcrec_opt.value(); + ASSERT_EQ(id, funcrec.id); + ASSERT_EQ(name, funcrec.function_name); + } + } +} + +TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) { + auto provider = 
std::get<0>(GetParam()); + auto registry = provider->get(); + + for (util::string_view name : kFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + } +} + +INSTANTIATE_TEST_SUITE_P( + Substrait, ExtensionIdRegistryTest, + testing::Values( + std::make_tuple(std::make_shared(), + "default"), + std::make_tuple(std::make_shared(), + "nested"))); + +TEST(ExtensionIdRegistryTest, RegisterTempTypes) { + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = substrait::MakeExtensionIdRegistry(); + + for (TypeName e : kTempTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + ASSERT_OK(registry->CanRegisterType(id, e.type)); + ASSERT_OK(registry->RegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type)); + ASSERT_OK(default_registry->CanRegisterType(id, e.type)); + } + } +} + +TEST(ExtensionIdRegistryTest, RegisterTempFunctions) { + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = substrait::MakeExtensionIdRegistry(); + + for (util::string_view name : kTempFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + ASSERT_OK(registry->CanRegisterFunction(id, name.to_string())); + ASSERT_OK(registry->RegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id, name.to_string())); + } + } +} + +TEST(ExtensionIdRegistryTest, RegisterNestedTypes) { + std::shared_ptr type1 = kTempTypeNames[0].type; + std::shared_ptr type2 = kTempTypeNames[1].type; + auto id1 = 
Id{kArrowExtTypesUri, kTempTypeNames[0].name}; + auto id2 = Id{kArrowExtTypesUri, kTempTypeNames[1].name}; + + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = nested_extension_id_registry(default_registry); + + ASSERT_OK(registry1->CanRegisterType(id1, type1)); + ASSERT_OK(registry1->RegisterType(id1, type1)); + + for (int j = 0; j < rounds; j++) { + auto registry2 = nested_extension_id_registry(&*registry1); + + ASSERT_OK(registry2->CanRegisterType(id2, type2)); + ASSERT_OK(registry2->RegisterType(id2, type2)); + ASSERT_RAISES(Invalid, registry2->CanRegisterType(id2, type2)); + ASSERT_RAISES(Invalid, registry2->RegisterType(id2, type2)); + ASSERT_OK(default_registry->CanRegisterType(id2, type2)); + } + + ASSERT_RAISES(Invalid, registry1->CanRegisterType(id1, type1)); + ASSERT_RAISES(Invalid, registry1->RegisterType(id1, type1)); + ASSERT_OK(default_registry->CanRegisterType(id1, type1)); + } +} + +TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) { + util::string_view name1 = kTempFunctionNames[0]; + util::string_view name2 = kTempFunctionNames[1]; + auto id1 = Id{kArrowExtTypesUri, name1}; + auto id2 = Id{kArrowExtTypesUri, name2}; + + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = nested_extension_id_registry(default_registry); + + ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string())); + ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string())); + + for (int j = 0; j < rounds; j++) { + auto registry2 = nested_extension_id_registry(&*registry1); + + ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string())); + ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string())); +
ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string())); + } + + ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string())); + ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string())); + } +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 638a354c6f2..b841a8db0e8 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -21,6 +21,7 @@ #include #include +#include #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -95,6 +96,16 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { virtual Status CanRegisterFunction(Id, const std::string& arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; + + /// \brief Add a symbol external to the plan yet used in an Id. + /// + /// This ensures the symbol, which is only viewed but not held by the Id, lives while + /// the extension set does. Symbols appearing in the Substrait plan are already held. 
+ const std::string& AddExternalSymbol(const std::string& symbol) { + return *external_symbols_.insert(symbol).first; + } + private: + std::set external_symbols_; }; constexpr util::string_view kArrowExtTypesUri = diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 53356d86a57..26804af731c 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -111,13 +111,15 @@ class SubstraitExecutor { } // namespace Result> ExecuteSerializedPlan( - const Buffer& substrait_buffer, const ExtensionIdRegistry* registry) { + const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry, + compute::FunctionRegistry* func_registry) { ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); // TODO(ARROW-15732) compute::ExecContext exec_context(arrow::default_memory_pool(), - ::arrow::internal::GetCpuThreadPool()); + ::arrow::internal::GetCpuThreadPool(), + func_registry); SubstraitExecutor executor(std::move(plan), exec_context); - RETURN_NOT_OK(executor.Init(substrait_buffer, registry)); + RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry)); ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); return sink_reader; } diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 27fc3a8aaed..98cf33cadb6 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -18,6 +18,7 @@ #pragma once #include +#include "arrow/compute/registry.h" #include "arrow/engine/substrait/api.h" #include "arrow/util/iterator.h" #include "arrow/util/optional.h" @@ -30,7 +31,8 @@ namespace substrait { /// \brief Retrieve a RecordBatchReader from a Substrait plan.
ARROW_ENGINE_EXPORT Result> ExecuteSerializedPlan( - const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR); + const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR, + compute::FunctionRegistry* func_registry = NULLPTR); /// \brief Get a Serialized Plan from a Substrait JSON plan. /// This is a helper method for Python tests. diff --git a/cpp/src/arrow/python/udf.cc b/cpp/src/arrow/python/udf.cc index 41309d27bb7..a79c419973d 100644 --- a/cpp/src/arrow/python/udf.cc +++ b/cpp/src/arrow/python/udf.cc @@ -103,7 +103,8 @@ struct PythonUdf { } // namespace Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options) { + const ScalarUdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(user_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -123,7 +124,9 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel))); - auto registry = compute::GetFunctionRegistry(); + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func))); return Status::OK(); } diff --git a/cpp/src/arrow/python/udf.h b/cpp/src/arrow/python/udf.h index 4ab3e7cc72b..138f9ee4908 100644 --- a/cpp/src/arrow/python/udf.h +++ b/cpp/src/arrow/python/udf.h @@ -50,9 +50,9 @@ using ScalarUdfWrapperCallback = std::function; /// \brief register a Scalar user-defined-function from Python -Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, - ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options); +Status ARROW_PYTHON_EXPORT RegisterScalarFunction( + PyObject* user_function, ScalarUdfWrapperCallback wrapper, + const 
ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); } // namespace py diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 8b09cbd445e..d65eb2e000e 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -27,6 +27,9 @@ cdef class ScalarUdfContext(_Weakrefable): cdef void init(self, const CScalarUdfContext& c_context) +cdef class BaseFunctionRegistry(_Weakrefable): + cdef CFunctionRegistry* registry + cdef class FunctionOptions(_Weakrefable): cdef: shared_ptr[CFunctionOptions] wrapped diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 96da505f763..728a873a2a2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -465,11 +465,15 @@ cdef _pack_compute_args(object values, vector[CDatum]* out): "for compute function") -cdef class FunctionRegistry(_Weakrefable): - cdef CFunctionRegistry* registry +cdef class FunctionRegistry(BaseFunctionRegistry): + cdef unique_ptr[CFunctionRegistry] up_registry - def __init__(self): - self.registry = GetFunctionRegistry() + def __init__(self, registry=None): + if registry is None: + self.registry = GetFunctionRegistry() + else: + self.registry = pyarrow_unwrap_function_registry(registry) + self.up_registry.reset(self.registry) def list_functions(self): """ @@ -502,6 +506,13 @@ def function_registry(): return _global_func_registry +def make_function_registry(): + up_registry = MakeFunctionRegistry() + c_registry = up_registry.get() + up_registry.release() + return FunctionRegistry(pyarrow_wrap_function_registry(c_registry)) + + def get_function(name): """ Get a function by name. @@ -2366,7 +2377,7 @@ def _get_scalar_udf_context(memory_pool, batch_length): def register_scalar_function(func, function_name, function_doc, in_types, - out_type): + out_type, func_registry=None): """ Register a user-defined scalar function. @@ -2407,6 +2418,8 @@ def register_scalar_function(func, function_name, function_doc, in_types, arity. 
out_type : DataType Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. Examples -------- @@ -2444,6 +2457,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, PyObject* c_function shared_ptr[CDataType] c_out_type CScalarUdfOptions c_options + CFunctionRegistry* c_func_registry if callable(func): c_function = func @@ -2485,5 +2499,11 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_options.input_types = c_in_types c_options.output_type = c_out_type + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = pyarrow_unwrap_function_registry(func_registry) + check_status(RegisterScalarFunction(c_function, - &_scalar_udf_callback, c_options)) + &_scalar_udf_callback, + c_options, c_func_registry)) diff --git a/python/pyarrow/_exec_plan.pxd b/python/pyarrow/_exec_plan.pxd index eafe19491b9..4d7529eba64 100644 --- a/python/pyarrow/_exec_plan.pxd +++ b/python/pyarrow/_exec_plan.pxd @@ -22,4 +22,4 @@ from pyarrow.includes.libarrow cimport * cdef is_supported_execplan_output_type(output_type) -cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=*) +cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=*, CFunctionRegistry* c_func_registry=*) diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx index c47ddebb894..7248ecc5c97 100644 --- a/python/pyarrow/_exec_plan.pyx +++ b/python/pyarrow/_exec_plan.pyx @@ -39,7 +39,7 @@ Initialize() # Initialise support for Datasets in ExecPlan cdef is_supported_execplan_output_type(output_type): return output_type in [Table, InMemoryDataset] -cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True): +cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True, CFunctionRegistry* c_func_registry=NULL): """ Internal Function to create an ExecPlan and 
run it. @@ -87,7 +87,7 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads c_executor = NULL c_exec_context = make_shared[CExecContext]( - c_default_memory_pool(), c_executor) + c_default_memory_pool(), c_executor, c_func_registry) c_exec_plan = GetResultValue(CExecPlan.Make(c_exec_context.get())) plan_iter = plan.begin() diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 4128b37df85..e458f70939e 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -15,9 +15,14 @@ # specific language governing permissions and limitations # under the License. # cython: language_level = 3 +import base64 +import cloudpickle +import inspect + from cython.operator cimport dereference as deref, preincrement as inc +from pyarrow import compute as pc from pyarrow import Buffer from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * @@ -26,30 +31,37 @@ from pyarrow.includes.libarrow_substrait cimport * from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan +from pyarrow._compute import make_function_registry def make_extension_id_registry(): cdef: - shared_ptr[CExtensionIdRegistry] c_registry + shared_ptr[CExtensionIdRegistry] c_extid_registry ExtensionIdRegistry registry with nogil: - c_registry = MakeExtensionIdRegistry() + c_extid_registry = MakeExtensionIdRegistry() + + return pyarrow_wrap_extension_id_registry(c_extid_registry) + - return pyarrow_wrap_extension_id_registry(c_registry) +def _get_udf_code(func): + return frombytes(base64.b64encode(cloudpickle.dumps(func))) -def get_udf_declarations(plan, registry): + +def get_udf_declarations(plan, extid_registry): cdef: shared_ptr[CBuffer] c_buf_plan - shared_ptr[CExtensionIdRegistry] c_registry + shared_ptr[CExtensionIdRegistry] c_extid_registry vector[CUdfDeclaration] c_decls vector[CUdfDeclaration].iterator c_decls_iter vector[pair[shared_ptr[CDataType], c_bool]].iterator c_in_types_iter c_buf_plan =
pyarrow_unwrap_buffer(plan) - c_registry = pyarrow_unwrap_extension_id_registry(registry) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) with nogil: - c_res_decls = DeserializePlanUdfs(deref(c_buf_plan), &deref(c_registry)) + c_res_decls = DeserializePlanUdfs( + deref(c_buf_plan), c_extid_registry.get()) c_decls = GetResultValue(c_res_decls) decls = [] @@ -73,33 +85,60 @@ def get_udf_declarations(plan, registry): inc(c_decls_iter) return decls -def register_function(registry, id_uri, id_name, arrow_function_name): + +def register_function(extid_registry, id_uri, id_name, arrow_function_name): cdef: c_string c_id_uri, c_id_name, c_arrow_function_name - shared_ptr[CExtensionIdRegistry] c_registry + shared_ptr[CExtensionIdRegistry] c_extid_registry CStatus c_status - c_registry = pyarrow_unwrap_extension_id_registry(registry) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) c_id_uri = id_uri or default_extension_types_uri() c_id_name = tobytes(id_name) c_arrow_function_name = tobytes(arrow_function_name) with nogil: c_status = RegisterFunction( - deref(c_registry), c_id_uri, c_id_name, c_arrow_function_name + deref(c_extid_registry), c_id_uri, c_id_name, c_arrow_function_name ) check_status(c_status) -def run_query_as(plan, registry, output_type=RecordBatchReader): + +def register_udf_declarations(plan, extid_registry, func_registry, udf_decls=None): + if udf_decls is None: + udf_decls = get_udf_declarations(plan, extid_registry) + for udf_decl in udf_decls: + udf_name = udf_decl["name"] + udf_func = cloudpickle.loads( + base64.b64decode(tobytes(udf_decl["code"]))) + udf_arg_names = list(inspect.signature(udf_func).parameters.keys()) + udf_arg_types = udf_decl["input_types"] + register_function(extid_registry, None, udf_name, udf_name) + pc.register_scalar_function( + udf_func, + udf_name, + {"summary": udf_decl["summary"], + "description": udf_decl["description"]}, + # range start from 1 to skip over udf scalar 
context argument + {udf_arg_names[i]: udf_arg_types[i][0] + for i in range(1 ,len(udf_arg_types))}, + udf_decl["output_type"][0], + func_registry, + ) + + +def run_query_as(plan, extid_registry, func_registry, output_type=RecordBatchReader): if output_type == RecordBatchReader: - return run_query(plan, registry) - return _run_query(plan, registry, output_type) + return run_query(plan, extid_registry, func_registry) + return _run_query(plan, extid_registry, func_registry, output_type) -def _run_query(plan, registry, output_type): + +def _run_query(plan, extid_registry, func_registry, output_type): cdef: shared_ptr[CBuffer] c_buf_plan - shared_ptr[CExtensionIdRegistry] c_registry + shared_ptr[CExtensionIdRegistry] c_extid_registry + CFunctionRegistry* c_func_registry CResult[vector[CDeclaration]] c_res_decls vector[CDeclaration] c_decls @@ -107,13 +146,16 @@ def _run_query(plan, registry, output_type): raise TypeError(f"Unsupported output type {output_type}") c_buf_plan = pyarrow_unwrap_buffer(plan) - c_registry = pyarrow_unwrap_extension_id_registry(registry) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + c_func_registry = pyarrow_unwrap_function_registry(func_registry) with nogil: - c_res_decls = DeserializePlans(deref(c_buf_plan), &deref(c_registry)) + c_res_decls = DeserializePlans( + deref(c_buf_plan), c_extid_registry.get()) c_decls = GetResultValue(c_res_decls) - return execplan([], output_type, c_decls) + return execplan([], output_type, c_decls, True, c_func_registry) + -def run_query(plan, registry): +def run_query(plan, extid_registry, func_registry): """ Execute a Substrait plan and read the results as a RecordBatchReader. @@ -121,19 +163,27 @@ def run_query(plan, registry): ---------- plan : Buffer The serialized Substrait plan to execute. + extid_registry : ExtensionIdRegistry + The extension-id-registry to execute with. + func_registry : FunctionRegistry + The function registry to execute with. 
""" cdef: shared_ptr[CBuffer] c_buf_plan - shared_ptr[CExtensionIdRegistry] c_registry + shared_ptr[CExtensionIdRegistry] c_extid_registry + CFunctionRegistry* c_func_registry CResult[shared_ptr[CRecordBatchReader]] c_res_reader shared_ptr[CRecordBatchReader] c_reader RecordBatchReader reader c_buf_plan = pyarrow_unwrap_buffer(plan) - c_registry = pyarrow_unwrap_extension_id_registry(registry) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + c_func_registry = pyarrow_unwrap_function_registry(func_registry) with nogil: - c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan), &deref(c_registry)) + c_res_reader = ExecuteSerializedPlan( + deref(c_buf_plan), c_extid_registry.get(), c_func_registry + ) c_reader = GetResultValue(c_res_reader) diff --git a/python/pyarrow/compute.pxi b/python/pyarrow/compute.pxi new file mode 100644 index 00000000000..f2684ba4211 --- /dev/null +++ b/python/pyarrow/compute.pxi @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# separating out this base class is easier than unifying it into +# FunctionRegistry, which lives outside libarrow +cdef class BaseFunctionRegistry(_Weakrefable): + cdef CFunctionRegistry* registry + +cdef class ExtensionIdRegistry(_Weakrefable): + def __cinit__(self): + self.registry = NULL + + def __init__(self): + raise TypeError("Do not call ExtensionIdRegistry's constructor directly, use " + "the `MakeExtensionIdRegistry` function instead.") + + cdef void init(self, shared_ptr[CExtensionIdRegistry]& registry): + self.sp_registry = registry + self.registry = registry.get() diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index c28bc8c0416..7dfda8ebf18 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2695,9 +2695,13 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": shared_ptr[CDataType] output_type CStatus RegisterScalarFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + CFunctionRegistry* registry) cdef extern from "arrow/engine/substrait/extension_set.h" namespace "arrow::engine" nogil: - cdef cppclass CExtensionIdRegistry" arrow::engine::ExtensionIdRegistry": - pass + cdef cppclass CExtensionIdRegistry" arrow::engine::ExtensionIdRegistry" + +cdef extern from "arrow/compute/registry_util.h" namespace "arrow::compute" nogil: + + unique_ptr[CFunctionRegistry] MakeFunctionRegistry() diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index b21497b8340..30d772b8b1a 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -30,8 +30,8 @@ cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogi c_string code c_string summary c_string description - vector[pair[shared_ptr[CDataType], c_bool]] input_types; - 
pair[shared_ptr[CDataType], c_bool] output_type; + vector[pair[shared_ptr[CDataType], c_bool]] input_types + pair[shared_ptr[CDataType], c_bool] output_type CResult[vector[CUdfDeclaration]] DeserializePlanUdfs(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) @@ -39,7 +39,7 @@ cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::subst shared_ptr[CExtensionIdRegistry] MakeExtensionIdRegistry() CStatus RegisterFunction(CExtensionIdRegistry& registry, const c_string& id_uri, const c_string& id_name, const c_string& arrow_function_name) - CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) + CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer, const CExtensionIdRegistry* extid_registry, CFunctionRegistry* func_registry) CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) CResult[vector[CDeclaration]] DeserializePlans(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 35678304729..5b6a5416958 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -594,6 +594,7 @@ cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor) cdef public object pyarrow_wrap_batch(const shared_ptr[CRecordBatch]& cbatch) cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable) +cdef public object pyarrow_wrap_function_registry(CFunctionRegistry* cregistry) cdef public object pyarrow_wrap_extension_id_registry(shared_ptr[CExtensionIdRegistry]& cregistry) # Unwrapping Python -> C++ @@ -623,4 +624,5 @@ cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor) cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch) cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table) +cdef public CFunctionRegistry* pyarrow_unwrap_function_registry(object 
registry) cdef public shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(object registry) diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index a665ea59c6e..653ce1064cd 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -169,6 +169,9 @@ include "builder.pxi" # Column, Table, Record Batch include "table.pxi" +# Compute registries +include "compute.pxi" + # Tensors include "tensor.pxi" diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index ee146c6dab8..607cc475553 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -418,10 +418,23 @@ cdef api object pyarrow_wrap_batch( return batch +cdef api bint pyarrow_is_function_registry(object registry): + return isinstance(registry, BaseFunctionRegistry) + + cdef api bint pyarrow_is_extension_id_registry(object registry): return isinstance(registry, ExtensionIdRegistry) +cdef api CFunctionRegistry* pyarrow_unwrap_function_registry(object registry): + cdef BaseFunctionRegistry reg + if pyarrow_is_function_registry(registry): + reg = (registry) + return reg.registry + + return NULL + + cdef api shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(object registry): cdef ExtensionIdRegistry reg if pyarrow_is_extension_id_registry(registry): @@ -431,6 +444,12 @@ cdef api shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(o return shared_ptr[CExtensionIdRegistry]() +cdef api object pyarrow_wrap_function_registry(CFunctionRegistry* cregistry): + cdef BaseFunctionRegistry registry = BaseFunctionRegistry.__new__(BaseFunctionRegistry) + registry.registry = cregistry + return registry + + cdef api object pyarrow_wrap_extension_id_registry( shared_ptr[CExtensionIdRegistry]& cregistry): cdef ExtensionIdRegistry registry = ExtensionIdRegistry.__new__(ExtensionIdRegistry) diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index eb7d9795491..3584bed3cb8 100644 --- 
a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -16,9 +16,12 @@ # under the License. from pyarrow._substrait import ( # noqa + make_function_registry, make_extension_id_registry, + _get_udf_code, get_udf_declarations, register_function, + register_udf_declarations, run_query_as, run_query, _parse_json_plan, diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 5c0afdab8b6..d4025f734ca 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2568,20 +2568,6 @@ cdef class RecordBatch(_PandasConvertible): return pyarrow_wrap_batch(c_batch) -cdef class ExtensionIdRegistry(_Weakrefable): - - def __cinit__(self): - self.registry = NULL - - def __init__(self): - raise TypeError("Do not call ExtensionIdRegistry's constructor directly, use " - "the `MakeExtensionIdRegistry` function instead.") - - cdef void init(self, shared_ptr[CExtensionIdRegistry]& registry): - self.sp_registry = registry - self.registry = registry.get() - - def _reconstruct_record_batch(columns, schema): """ Internal: reconstruct RecordBatch from pickled components. 
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 8df35bbba44..b5c9d966f2a 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -22,6 +22,7 @@ import pyarrow as pa from pyarrow.lib import tobytes from pyarrow.lib import ArrowInvalid +from pyarrow.substrait import make_extension_id_registry try: import pyarrow.substrait as substrait @@ -74,7 +75,9 @@ def test_run_serialized_query(tmpdir): buf = pa._substrait._parse_json_plan(query) - reader = substrait.run_query(buf) + extid_registry = substrait.make_extension_id_registry() + func_registry = substrait.make_function_registry() + reader = substrait.run_query(buf, extid_registry, func_registry) res_tb = reader.read_all() assert table.select(["foo"]) == res_tb.select(["foo"]) @@ -88,6 +91,8 @@ def test_invalid_plan(): } """ buf = pa._substrait._parse_json_plan(tobytes(query)) + extid_registry = substrait.make_extension_id_registry() + func_registry = substrait.make_function_registry() exec_message = "Empty substrait plan is passed." with pytest.raises(ArrowInvalid, match=exec_message): - substrait.run_query(buf) + substrait.run_query(buf, extid_registry, func_registry) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index bdcda9b2ff5..f358427324f 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -16,8 +16,6 @@ # under the License. 
-import base64 -import cloudpickle import os import pytest @@ -517,9 +515,11 @@ def demean_and_zscore(scl_udf_ctx, v): std = v.std() return v - mean, (v - mean) / std + def twice_and_add_2(scl_udf_ctx, v): return 2 * v, v + 2 + def twice(scl_udf_ctx, v): return DoubleArray.from_pandas((2 * v.to_pandas())) @@ -651,9 +651,10 @@ def test_elementwise_scalar_udf_in_substrait_query(tmpdir): } """ # TODO: replace with ipc when the support is finalized in C++ - code = frombytes(base64.b64encode(cloudpickle.dumps(twice))) + code = substrait._get_udf_code(twice) path = os.path.join(str(tmpdir), 'substrait_data.arrow') - table = pa.table([["a", "b", "a", "b", "a"], [1.0, 2.0, 3.0, 4.0, 5.0]], names=['key', 'value']) + table = pa.table([["a", "b", "a", "b", "a"], [ + 1.0, 2.0, 3.0, 4.0, 5.0]], names=['key', 'value']) with pa.ipc.RecordBatchFileWriter(path, schema=table.schema) as writer: writer.write_table(table) @@ -661,23 +662,14 @@ def test_elementwise_scalar_udf_in_substrait_query(tmpdir): plan = substrait._parse_json_plan(query) - registry = substrait.make_extension_id_registry() - udf_decls = substrait.get_udf_declarations(plan, registry) - for udf_decl in udf_decls: - substrait.register_function(registry, None, udf_decl["name"], udf_decl["name"]) - pc.register_scalar_function( - cloudpickle.loads(base64.b64decode(tobytes(udf_decl["code"]))), - udf_decl["name"], - {"summary": udf_decl["summary"], "description": udf_decl["description"]}, - {f"arg$i": type_nullable_pair[0] - for i, type_nullable_pair in enumerate(udf_decl["input_types"]) - }, - udf_decl["output_type"][0], - ) + extid_registry = substrait.make_extension_id_registry() + func_registry = substrait.make_function_registry() + substrait.register_udf_declarations(plan, extid_registry, func_registry) - reader = substrait.run_query(plan, registry) + reader = substrait.run_query(plan, extid_registry, func_registry) res_tb = reader.read_all() assert len(res_tb) == len(table) - assert res_tb.schema == 
pa.schema([("key", pa.string()), ("value", pa.float64()), ("twice", pa.float64())]) + assert res_tb.schema == pa.schema( + [("key", pa.string()), ("value", pa.float64()), ("twice", pa.float64())]) assert res_tb.drop(["twice"]) == table From 90f20d08a5179f8278f4d71fa1e811b5e42d5400 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 1 Jun 2022 10:06:28 -0400 Subject: [PATCH 14/19] Fix parameter order and doc of DeserializePlan functions --- cpp/src/arrow/engine/substrait/serde.cc | 20 ++++++++++---------- cpp/src/arrow/engine/substrait/serde.h | 11 +++++++---- cpp/src/arrow/engine/substrait/serde_test.cc | 4 ++-- cpp/src/arrow/engine/substrait/util.cc | 2 -- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index d2ccdd87711..bc30bd2b635 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -61,7 +61,7 @@ Result DeserializeRelation(const Buffer& buf, static Result> DeserializePlans( const Buffer& buf, const std::string& factory_name, std::function()> options_factory, - ExtensionSet* ext_set_out, const ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry)); @@ -93,7 +93,7 @@ static Result> DeserializePlans( Result> DeserializePlans( const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out, const ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) { return DeserializePlans( buf, "consuming_sink", @@ -102,23 +102,23 @@ Result> DeserializePlans( compute::ConsumingSinkNodeOptions{consumer_factory()} ); }, - ext_set_out, - registry + registry, + ext_set_out ); } Result> DeserializePlans( const Buffer& buf, const WriteOptionsFactory& write_options_factory, - ExtensionSet* 
ext_set_out, const ExtensionIdRegistry* registry) { - return DeserializePlans(buf, "write", write_options_factory, ext_set_out, registry); + const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) { + return DeserializePlans(buf, "write", write_options_factory, registry, ext_set_out); } Result DeserializePlan(const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out, - const ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry, + ExtensionSet* ext_set_out) { ARROW_ASSIGN_OR_RAISE(auto declarations, - DeserializePlans(buf, consumer_factory, ext_set_out, registry)); + DeserializePlans(buf, consumer_factory, registry, ext_set_out)); if (declarations.size() > 1) { return Status::Invalid("DeserializePlan does not support multiple root relations"); } else { @@ -151,7 +151,7 @@ Result> DeserializePlanUdfs( ARROW_ASSIGN_OR_RAISE(auto output_type, FromProto(udf.output_type(), ext_set)); decls.push_back(std::move(UdfDeclaration{ fn.name(), - udf.code(), + udf.code(), udf.summary(), udf.description(), std::move(input_types), diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 775fc4e0a4c..1a025b07d65 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -45,13 +45,14 @@ using ConsumerFactory = std::function /// message /// \param[in] consumer_factory factory function for generating the node that consumes /// the batches produced by each toplevel Substrait relation +/// \param[in] registry an extension-id-registry to use, or null for the default one. /// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait /// Plan is returned here. 
/// \return a vector of ExecNode declarations, one for each toplevel relation in the /// Substrait Plan ARROW_ENGINE_EXPORT Result> DeserializePlans( const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out = NULLPTR, const ExtensionIdRegistry* registry = NULLPTR); + const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR); /// \brief Deserializes a single-relation Substrait Plan message to an execution plan /// @@ -59,14 +60,15 @@ ARROW_ENGINE_EXPORT Result> DeserializePlans( /// message /// \param[in] consumer_factory factory function for generating the node that consumes /// the batches produced by each toplevel Substrait relation +/// \param[in] registry an extension-id-registry to use, or null for the default one. /// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait /// Plan is returned here. /// \return an ExecNode corresponding to the single toplevel relation in the Substrait /// Plan Result DeserializePlan(const Buffer& buf, const ConsumerFactory& consumer_factory, - ExtensionSet* ext_set_out = NULLPTR, - const ExtensionIdRegistry* registry = NULLPTR); + const ExtensionIdRegistry* registry = NULLPTR, + ExtensionSet* ext_set_out = NULLPTR); /// Factory function type for generating the write options of a node consuming the batches /// produced by each toplevel Substrait relation when deserializing a Substrait Plan. 
@@ -78,13 +80,14 @@ using WriteOptionsFactory = std::function> DeserializePlans( const Buffer& buf, const WriteOptionsFactory& write_options_factory, - ExtensionSet* ext_set = NULLPTR, const ExtensionIdRegistry* registry = NULLPTR); + const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set = NULLPTR); struct ARROW_ENGINE_EXPORT UdfDeclaration { std::string name; diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index fae23f200de..90708bc73f3 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -701,7 +701,7 @@ TEST(Substrait, ExtensionSetFromPlan) { auto sink_decls, DeserializePlans( *buf, [] { return std::shared_ptr{nullptr}; }, - &ext_set)); + NULLPTR, &ext_set)); EXPECT_OK_AND_ASSIGN(auto decoded_null_type, ext_set.DecodeType(42)); EXPECT_EQ(decoded_null_type.id.uri, kArrowExtTypesUri); @@ -737,7 +737,7 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) { Invalid, DeserializePlans( *buf, [] { return std::shared_ptr{nullptr}; }, - &ext_set)); + NULLPTR, &ext_set)); } Result GetSubstraitJSON() { diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 26804af731c..9f3d39f5b71 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -95,7 +95,6 @@ class SubstraitExecutor { ARROW_ASSIGN_OR_RAISE(declarations_, engine::DeserializePlans(substrait_buffer, consumer_factory, - NULLPTR, registry)); return Status::OK(); } @@ -133,7 +132,6 @@ Result> DeserializePlans( return engine::DeserializePlans( buffer, []() { return std::make_shared(); }, - NULLPTR, registry ); } From 879999ef5527f14f4641846d49e389189ee7213c Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 9 Jun 2022 14:09:33 -0400 Subject: [PATCH 15/19] fix registry scoping --- cpp/src/arrow/compute/exec/filter_node.cc | 3 +- cpp/src/arrow/compute/exec/hash_join.h | 2 +- 
cpp/src/arrow/compute/exec/hash_join_node.cc | 11 +- cpp/src/arrow/compute/exec/project_node.cc | 3 +- .../kernels/base_arithmetic_internal.h | 96 ++++++ cpp/src/arrow/compute/registry.cc | 275 ++++++++---------- cpp/src/arrow/compute/registry.h | 66 +++-- cpp/src/arrow/compute/registry_test.cc | 23 +- cpp/src/arrow/engine/substrait/serde_test.cc | 8 +- cpp/src/arrow/engine/substrait/util.cc | 2 +- python/pyarrow/_compute.pxd | 2 +- python/pyarrow/_compute.pyx | 19 +- python/pyarrow/_substrait.pyx | 26 +- python/pyarrow/public-api.pxi | 4 + python/pyarrow/tests/test_udf.py | 28 ++ 15 files changed, 344 insertions(+), 224 deletions(-) diff --git a/cpp/src/arrow/compute/exec/filter_node.cc b/cpp/src/arrow/compute/exec/filter_node.cc index 0c849cb0435..b424da35f85 100644 --- a/cpp/src/arrow/compute/exec/filter_node.cc +++ b/cpp/src/arrow/compute/exec/filter_node.cc @@ -50,7 +50,8 @@ class FilterNode : public MapNode { auto filter_expression = filter_options.filter_expression; if (!filter_expression.IsBound()) { - ARROW_ASSIGN_OR_RAISE(filter_expression, filter_expression.Bind(*schema)); + ARROW_ASSIGN_OR_RAISE(filter_expression, + filter_expression.Bind(*schema, plan->exec_context())); } if (filter_expression.type()->id() != Type::BOOL) { diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 9739cbc6436..84685989a1e 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -59,7 +59,7 @@ class ARROW_EXPORT HashJoinSchema { const std::string& right_field_name_prefix); Result BindFilter(Expression filter, const Schema& left_schema, - const Schema& right_schema); + const Schema& right_schema, ExecContext* exec_context); std::shared_ptr MakeOutputSchema(const std::string& left_field_name_suffix, const std::string& right_field_name_suffix); diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index e47d6095542..7f33df15b85 100644 --- 
a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -336,7 +336,8 @@ std::shared_ptr HashJoinSchema::MakeOutputSchema( Result HashJoinSchema::BindFilter(Expression filter, const Schema& left_schema, - const Schema& right_schema) { + const Schema& right_schema, + ExecContext* exec_context) { if (filter.IsBound() || filter == literal(true)) { return std::move(filter); } @@ -367,7 +368,7 @@ Result HashJoinSchema::BindFilter(Expression filter, filter); // Step 3: Bind - ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema)); + ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema, exec_context)); if (filter.type()->id() != Type::BOOL) { return Status::TypeError("Filter expression must evaluate to bool, but ", filter.ToString(), " evaluates to ", @@ -499,9 +500,9 @@ class HashJoinNode : public ExecNode { join_options.output_suffix_for_left, join_options.output_suffix_for_right)); } - ARROW_ASSIGN_OR_RAISE( - Expression filter, - schema_mgr->BindFilter(join_options.filter, left_schema, right_schema)); + ARROW_ASSIGN_OR_RAISE(Expression filter, + schema_mgr->BindFilter(join_options.filter, left_schema, + right_schema, plan->exec_context())); // Generate output schema std::shared_ptr output_schema = schema_mgr->MakeOutputSchema( diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index b8fb64c5d54..cad8d7c45ae 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -64,7 +64,8 @@ class ProjectNode : public MapNode { int i = 0; for (auto& expr : exprs) { if (!expr.IsBound()) { - ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*inputs[0]->output_schema())); + ARROW_ASSIGN_OR_RAISE( + expr, expr.Bind(*inputs[0]->output_schema(), plan->exec_context())); } fields[i] = field(std::move(names[i]), expr.type()); ++i; diff --git a/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h 
b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h index 1707ed7c137..bc51d53f316 100644 --- a/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h +++ b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h @@ -425,6 +425,102 @@ struct DivideChecked { } }; +// if at least one argument is NaN, returns the first one that is NaN +struct Minimum { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left < right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left < right ? left : right; + } +}; + +// if both arguments are NaN, returns the first one +struct MinimumChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left < right ? left : right; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left < right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left < right ? 
left : right; + } +}; + +// if at least one argument is NaN, returns the first one that is NaN +struct Maximum { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return std::isnan(left) ? left : left > right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left > right ? left : right; + } +}; + +// if both arguments are NaN, returns the first one +struct MaximumChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left > right ? left : right; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::isnan(left) && std::isnan(right) ? left : left > right ? left : right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left > right ? 
left : right; + } +}; + struct Negate { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 82b677ae72f..fe7c6fa8ad1 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -33,114 +33,73 @@ namespace arrow { namespace compute { class FunctionRegistry::FunctionRegistryImpl { - private: - using FuncAdd = std::function)>; - - const FuncAdd kFuncAddNoOp = [](const std::string& name, - std::shared_ptr func) {}; - const FuncAdd kFuncAddDo = [this](const std::string& name, - std::shared_ptr func) { - name_to_function_[name] = func; - }; - - Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, - FuncAdd add) { -#ifndef NDEBUG - // This validates docstrings extensively, so don't waste time on it - // in release builds. - RETURN_NOT_OK(function->Validate()); -#endif - - std::lock_guard mutation_guard(lock_); - - const std::string& name = function->name(); - auto it = name_to_function_.find(name); - if (it != name_to_function_.end() && !allow_overwrite) { - return Status::KeyError("Already have a function registered with name: ", name); - } - add(name, std::move(function)); - return Status::OK(); - } - public: - virtual Status CanAddFunction(std::shared_ptr function, - bool allow_overwrite) { - return DoAddFunction(function, allow_overwrite, kFuncAddNoOp); - } - - virtual Status AddFunction(std::shared_ptr function, bool allow_overwrite) { - return DoAddFunction(function, allow_overwrite, kFuncAddDo); - } - - private: - Status DoAddAlias(const std::string& target_name, const std::string& source_name, - FuncAdd add) { - std::lock_guard mutation_guard(lock_); + explicit FunctionRegistryImpl(FunctionRegistryImpl* parent = NULLPTR) + : parent_(parent) {} + ~FunctionRegistryImpl() {} - auto func_res = GetFunction(source_name); // must not acquire the mutex - if (!func_res.ok()) { - return 
Status::KeyError("No function registered with name: ", source_name); + Status CanAddFunction(std::shared_ptr function, bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); } - add(target_name, func_res.ValueOrDie()); - return Status::OK(); + return DoAddFunction(function, allow_overwrite, /*add=*/false); } - public: - virtual Status CanAddAlias(const std::string& target_name, - const std::string& source_name) { - return DoAddAlias(target_name, source_name, kFuncAddNoOp); + Status AddFunction(std::shared_ptr function, bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunction(function, allow_overwrite)); + } + return DoAddFunction(function, allow_overwrite, /*add=*/true); } - virtual Status AddAlias(const std::string& target_name, - const std::string& source_name) { - return DoAddAlias(target_name, source_name, kFuncAddDo); + Status CanAddAlias(const std::string& target_name, const std::string& source_name) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, + /*allow_overwrite=*/false)); + } + return DoAddAlias(target_name, source_name, /*add=*/false); } - private: - using FuncOptTypeAdd = std::function; - - const FuncOptTypeAdd kFuncOptTypeAddNoOp = [](const FunctionOptionsType* options_type) { - }; - const FuncOptTypeAdd kFuncOptTypeAddDo = - [this](const FunctionOptionsType* options_type) { - name_to_options_type_[options_type->type_name()] = options_type; - }; - - Status DoAddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite, FuncOptTypeAdd add) { - std::lock_guard mutation_guard(lock_); - - const std::string name = options_type->type_name(); - auto it = name_to_options_type_.find(name); - if (it != name_to_options_type_.end() && !allow_overwrite) { - return Status::KeyError( - "Already have a function options type registered with name: ", name); + Status AddAlias(const std::string& 
target_name, const std::string& source_name) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionName(target_name, + /*allow_overwrite=*/false)); } - add(options_type); - return Status::OK(); + return DoAddAlias(target_name, source_name, /*add=*/true); } - public: - virtual Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) { - return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddNoOp); + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + } + return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/false); } - virtual Status AddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) { - return DoAddFunctionOptionsType(options_type, allow_overwrite, kFuncOptTypeAddDo); + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionOptionsType(options_type, allow_overwrite)); + } + return DoAddFunctionOptionsType(options_type, allow_overwrite, /*add=*/true); } - virtual Result> GetFunction(const std::string& name) const { + Result> GetFunction(const std::string& name) const { auto it = name_to_function_.find(name); if (it == name_to_function_.end()) { + if (parent_ != NULLPTR) { + return parent_->GetFunction(name); + } return Status::KeyError("No function registered with name: ", name); } return it->second; } - virtual std::vector GetFunctionNames() const { + std::vector GetFunctionNames() const { std::vector results; + if (parent_ != NULLPTR) { + results = parent_->GetFunctionNames(); + } for (auto it : name_to_function_) { results.push_back(it.first); } @@ -148,102 +107,103 @@ class FunctionRegistry::FunctionRegistryImpl { return results; } - 
virtual Result GetFunctionOptionsType( + Result GetFunctionOptionsType( const std::string& name) const { auto it = name_to_options_type_.find(name); if (it == name_to_options_type_.end()) { + if (parent_ != NULLPTR) { + return parent_->GetFunctionOptionsType(name); + } return Status::KeyError("No function options type registered with name: ", name); } return it->second; } - virtual int num_functions() const { return static_cast(name_to_function_.size()); } + int num_functions() const { + return (parent_ == NULLPTR ? 0 : parent_->num_functions()) + + static_cast(name_to_function_.size()); + } private: - std::mutex lock_; - std::unordered_map> name_to_function_; - std::unordered_map name_to_options_type_; -}; - -class FunctionRegistry::NestedFunctionRegistryImpl - : public FunctionRegistry::FunctionRegistryImpl { - public: - explicit NestedFunctionRegistryImpl(FunctionRegistry::FunctionRegistryImpl* parent) - : parent_(parent) {} - - Status CanAddFunction(std::shared_ptr function, - bool allow_overwrite) override { - return parent_->CanAddFunction(function, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::CanAddFunction(function, - allow_overwrite); + // must not acquire mutex + Status CanAddFunctionName(const std::string& name, bool allow_overwrite) { + if (parent_ != NULLPTR) { + RETURN_NOT_OK(parent_->CanAddFunctionName(name, allow_overwrite)); + } + if (!allow_overwrite) { + auto it = name_to_function_.find(name); + if (it != name_to_function_.end()) { + return Status::KeyError("Already have a function registered with name: ", name); + } + } + return Status::OK(); } - Status AddFunction(std::shared_ptr function, bool allow_overwrite) override { - return parent_->CanAddFunction(function, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::AddFunction(function, allow_overwrite); + // must not acquire mutex + Status CanAddOptionsTypeName(const std::string& name, bool allow_overwrite) { + if (parent_ != NULLPTR) { + 
RETURN_NOT_OK(parent_->CanAddOptionsTypeName(name, allow_overwrite)); + } + if (!allow_overwrite) { + auto it = name_to_options_type_.find(name); + if (it != name_to_options_type_.end()) { + return Status::KeyError( + "Already have a function options type registered with name: ", name); + } + } + return Status::OK(); } - Status CanAddAlias(const std::string& target_name, - const std::string& source_name) override { - Status st = - FunctionRegistry::FunctionRegistryImpl::CanAddAlias(target_name, source_name); - return st.ok() ? st : parent_->CanAddAlias(target_name, source_name); - } + Status DoAddFunction(std::shared_ptr function, bool allow_overwrite, + bool add) { +#ifndef NDEBUG + // This validates docstrings extensively, so don't waste time on it + // in release builds. + RETURN_NOT_OK(function->Validate()); +#endif - Status AddAlias(const std::string& target_name, - const std::string& source_name) override { - Status st = - FunctionRegistry::FunctionRegistryImpl::AddAlias(target_name, source_name); - return st.ok() ? 
st : parent_->AddAlias(target_name, source_name); - } + std::lock_guard mutation_guard(lock_); - Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) override { - return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::CanAddFunctionOptionsType( - options_type, allow_overwrite); + const std::string& name = function->name(); + RETURN_NOT_OK(CanAddFunctionName(name, allow_overwrite)); + if (add) { + name_to_function_[name] = std::move(function); + } + return Status::OK(); } - Status AddFunctionOptionsType(const FunctionOptionsType* options_type, - bool allow_overwrite = false) override { - return parent_->CanAddFunctionOptionsType(options_type, allow_overwrite) & - FunctionRegistry::FunctionRegistryImpl::AddFunctionOptionsType( - options_type, allow_overwrite); - } + Status DoAddAlias(const std::string& target_name, const std::string& source_name, + bool add) { + // source name must exist in this registry or the parent + // check outside mutex, in case GetFunction leads to mutex acquisition + ARROW_ASSIGN_OR_RAISE(auto func, GetFunction(source_name)); - Result> GetFunction(const std::string& name) const override { - auto func_res = FunctionRegistry::FunctionRegistryImpl::GetFunction(name); - if (func_res.ok()) { - return func_res; + std::lock_guard mutation_guard(lock_); + + // target name must be available in this registry and the parent + RETURN_NOT_OK(CanAddFunctionName(target_name, /*allow_overwrite=*/false)); + if (add) { + name_to_function_[target_name] = func; } - return parent_->GetFunction(name); + return Status::OK(); } - std::vector GetFunctionNames() const override { - auto names = parent_->GetFunctionNames(); - auto more_names = FunctionRegistry::FunctionRegistryImpl::GetFunctionNames(); - names.insert(names.end(), std::make_move_iterator(more_names.begin()), - std::make_move_iterator(more_names.end())); - return names; - } + Status 
DoAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite, bool add) { + std::lock_guard mutation_guard(lock_); - Result GetFunctionOptionsType( - const std::string& name) const override { - auto options_type_res = - FunctionRegistry::FunctionRegistryImpl::GetFunctionOptionsType(name); - if (options_type_res.ok()) { - return options_type_res; + const std::string name = options_type->type_name(); + RETURN_NOT_OK(CanAddOptionsTypeName(name, /*allow_overwrite=*/false)); + if (add) { + name_to_options_type_[options_type->type_name()] = options_type; } - return parent_->GetFunctionOptionsType(name); - } - - int num_functions() const override { - return parent_->num_functions() + - FunctionRegistry::FunctionRegistryImpl::num_functions(); + return Status::OK(); } - private: - FunctionRegistry::FunctionRegistryImpl* parent_; + FunctionRegistryImpl* parent_; + std::mutex lock_; + std::unordered_map> name_to_function_; + std::unordered_map name_to_options_type_; }; std::unique_ptr FunctionRegistry::Make() { @@ -252,12 +212,7 @@ std::unique_ptr FunctionRegistry::Make() { std::unique_ptr FunctionRegistry::Make(FunctionRegistry* parent) { return std::unique_ptr(new FunctionRegistry( - new FunctionRegistry::NestedFunctionRegistryImpl(&*parent->impl_))); -} - -std::unique_ptr FunctionRegistry::Make( - std::unique_ptr parent) { - return FunctionRegistry::Make(&*parent); + new FunctionRegistry::FunctionRegistryImpl(parent->impl_.get()))); } FunctionRegistry::FunctionRegistry() : FunctionRegistry(new FunctionRegistryImpl()) {} diff --git a/cpp/src/arrow/compute/registry.h b/cpp/src/arrow/compute/registry.h index de074e10d92..97abe7e9fde 100644 --- a/cpp/src/arrow/compute/registry.h +++ b/cpp/src/arrow/compute/registry.h @@ -45,59 +45,66 @@ class FunctionOptionsType; /// lower-level function execution. class ARROW_EXPORT FunctionRegistry { public: - virtual ~FunctionRegistry(); + ~FunctionRegistry(); - /// \brief Construct a new registry. 
Most users only need to use the global - /// registry + /// \brief Construct a new registry. + /// + /// Most users only need to use the global registry. static std::unique_ptr Make(); - /// \brief Construct a new nested registry with the given parent. Most users only need - /// to use the global registry + /// \brief Construct a new nested registry with the given parent. + /// + /// Most users only need to use the global registry. The returned registry never changes + /// its parent, even when an operation allows overwriting. static std::unique_ptr Make(FunctionRegistry* parent); - /// \brief Construct a new nested registry with the given parent. Most users only need - /// to use the global registry - static std::unique_ptr Make(std::unique_ptr parent); - - /// \brief Checks whether a new function can be added to the registry. Returns - /// Status::KeyError if a function with the same name is already registered + /// \brief Check whether a new function can be added to the registry. + /// + /// \returns Status::KeyError if a function with the same name is already registered. Status CanAddFunction(std::shared_ptr function, bool allow_overwrite = false); - /// \brief Add a new function to the registry. Returns Status::KeyError if a - /// function with the same name is already registered + /// \brief Add a new function to the registry. + /// + /// \returns Status::KeyError if a function with the same name is already registered. Status AddFunction(std::shared_ptr function, bool allow_overwrite = false); - /// \brief Checks whether an alias can be added for the given function name. Returns - /// Status::KeyError if the function with the given name is not registered + /// \brief Check whether an alias can be added for the given function name. + /// + /// \returns Status::KeyError if the function with the given name is not registered. Status CanAddAlias(const std::string& target_name, const std::string& source_name); - /// \brief Add alias for the given function name.
Returns Status::KeyError if the - /// function with the given name is not registered + /// \brief Add alias for the given function name. + /// + /// \returns Status::KeyError if the function with the given name is not registered. Status AddAlias(const std::string& target_name, const std::string& source_name); - /// \brief Checks whether a new function options type can be added to the registry. - /// Returns Status::KeyError if a function options type with the same name is already - /// registered + /// \brief Check whether a new function options type can be added to the registry. + /// + /// \returns Status::KeyError if a function options type with the same name is already + /// registered. Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false); - /// \brief Add a new function options type to the registry. Returns Status::KeyError if - /// a function options type with the same name is already registered + /// \brief Add a new function options type to the registry. + /// + /// \returns Status::KeyError if a function options type with the same name is already + /// registered. Status AddFunctionOptionsType(const FunctionOptionsType* options_type, bool allow_overwrite = false); - /// \brief Retrieve a function by name from the registry + /// \brief Retrieve a function by name from the registry. Result> GetFunction(const std::string& name) const; - /// \brief Return vector of all entry names in the registry. Helpful for - /// displaying a manifest of available functions + /// \brief Return vector of all entry names in the registry. + /// + /// Helpful for displaying a manifest of available functions. std::vector GetFunctionNames() const; - /// \brief Retrieve a function options type by name from the registry + /// \brief Retrieve a function options type by name from the registry. 
Result GetFunctionOptionsType( const std::string& name) const; - /// \brief The number of currently registered functions + /// \brief The number of currently registered functions. int num_functions() const; private: @@ -108,12 +115,9 @@ class ARROW_EXPORT FunctionRegistry { std::unique_ptr impl_; explicit FunctionRegistry(FunctionRegistryImpl* impl); - - class NestedFunctionRegistryImpl; - friend class NestedFunctionRegistryImpl; }; -/// \brief Return the process-global function registry +/// \brief Return the process-global function registry. ARROW_EXPORT FunctionRegistry* GetFunctionRegistry(); } // namespace compute diff --git a/cpp/src/arrow/compute/registry_test.cc b/cpp/src/arrow/compute/registry_test.cc index 319b6be7c08..937515af4ac 100644 --- a/cpp/src/arrow/compute/registry_test.cc +++ b/cpp/src/arrow/compute/registry_test.cc @@ -54,7 +54,7 @@ TEST_P(TestRegistry, Basics) { int n_funcs = get_num_funcs(); auto get_func_names = std::get<2>(GetParam()); std::vector func_names = get_func_names(); - ASSERT_EQ(n_funcs + 0, registry_->num_functions()); + ASSERT_EQ(n_funcs, registry_->num_functions()); std::shared_ptr func = std::make_shared( "f1", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); @@ -86,6 +86,7 @@ TEST_P(TestRegistry, Basics) { for (auto name : {"f0", "f1"}) { expected_names.push_back(name); } + std::sort(expected_names.begin(), expected_names.end()); ASSERT_EQ(expected_names, registry_->GetFunctionNames()); // Aliases @@ -145,22 +146,23 @@ TEST(TestRegistry, RegisterTempAliases) { } } -template +template class ExampleOptions : public FunctionOptions { public: explicit ExampleOptions(std::shared_ptr value); std::shared_ptr value; }; -template +template class ExampleOptionsType : public FunctionOptionsType { public: static const FunctionOptionsType* GetInstance() { - static std::unique_ptr instance(new ExampleOptionsType()); + static std::unique_ptr instance( + new ExampleOptionsType()); return instance.get(); } const char* type_name() const 
override { - static std::string name = std::string("example") + std::to_string(N); + static std::string name = std::string("example") + std::to_string(kExampleSeqNum); return name.c_str(); } std::string Stringify(const FunctionOptions& options) const override { @@ -171,13 +173,14 @@ class ExampleOptionsType : public FunctionOptionsType { return true; } std::unique_ptr Copy(const FunctionOptions& options) const override { - const auto& opts = static_cast&>(options); - return arrow::internal::make_unique>(opts.value); + const auto& opts = static_cast&>(options); + return arrow::internal::make_unique>(opts.value); } }; -template -ExampleOptions::ExampleOptions(std::shared_ptr value) - : FunctionOptions(ExampleOptionsType::GetInstance()), value(std::move(value)) {} +template +ExampleOptions::ExampleOptions(std::shared_ptr value) + : FunctionOptions(ExampleOptionsType::GetInstance()), + value(std::move(value)) {} TEST(TestRegistry, RegisterTempFunctionOptionsType) { auto default_registry = GetFunctionRegistry(); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index d2966601408..4f8379d2911 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -903,7 +903,7 @@ TEST(Substrait, JoinPlanBasic) { auto sink_decls, DeserializePlans( *buf, [] { return std::shared_ptr{nullptr}; }, - &ext_set)); + NULLPTR, &ext_set)); auto join_decl = sink_decls[0].inputs[0]; @@ -1035,7 +1035,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { Invalid, DeserializePlans( *buf, [] { return std::shared_ptr{nullptr}; }, - &ext_set)); + NULLPTR, &ext_set)); } TEST(Substrait, JoinPlanInvalidExpression) { @@ -1102,7 +1102,7 @@ TEST(Substrait, JoinPlanInvalidExpression) { Invalid, DeserializePlans( *buf, [] { return std::shared_ptr{nullptr}; }, - &ext_set)); + NULLPTR, &ext_set)); } TEST(Substrait, JoinPlanInvalidKeys) { @@ -1170,7 +1170,7 @@ TEST(Substrait, JoinPlanInvalidKeys) { Invalid, 
DeserializePlans( *buf, [] { return std::shared_ptr{nullptr}; }, - &ext_set)); + NULLPTR, &ext_set)); } } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 29b9cf97fd8..c6d45d50c5a 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -112,11 +112,11 @@ class SubstraitExecutor { Result> ExecuteSerializedPlan( const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry, compute::FunctionRegistry* func_registry) { - ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); // TODO(ARROW-15732) compute::ExecContext exec_context(arrow::default_memory_pool(), ::arrow::internal::GetCpuThreadPool(), func_registry); + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); SubstraitExecutor executor(std::move(plan), exec_context); RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry)); ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index d65eb2e000e..c3266c0cf9f 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -27,7 +27,7 @@ cdef class ScalarUdfContext(_Weakrefable): cdef void init(self, const CScalarUdfContext& c_context) -cdef class BaseFunctionRegistry(_Weakrefable): +cdef class FunctionRegistry(_Weakrefable): cdef CFunctionRegistry* registry cdef class FunctionOptions(_Weakrefable): diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 162e0a9be1c..60431fa71c8 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -465,15 +465,13 @@ cdef _pack_compute_args(object values, vector[CDatum]* out): "for compute function") -cdef class FunctionRegistry(BaseFunctionRegistry): - cdef unique_ptr[CFunctionRegistry] up_registry - +cdef class FunctionRegistry(_Weakrefable): def __init__(self, registry=None): if registry is None: self.registry = GetFunctionRegistry() 
else: self.registry = pyarrow_unwrap_function_registry(registry) - self.up_registry.reset(self.registry) + print(f"self: {self} , self.registry: {self.registry != NULL}") def list_functions(self): """ @@ -510,6 +508,7 @@ def make_function_registry(): up_registry = MakeFunctionRegistry() c_registry = up_registry.get() up_registry.release() + print(f"up_registry: {c_registry != NULL}") return FunctionRegistry(pyarrow_wrap_function_registry(c_registry)) @@ -2527,7 +2526,11 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_func_name = tobytes(function_name) - func_spec = inspect.getfullargspec(func) + try: + func_spec = inspect.getfullargspec(func) + is_varargs = func_spec.varargs is not None + except: + is_varargs = True num_args = -1 if isinstance(in_types, dict): for in_type in in_types.values(): @@ -2539,7 +2542,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, raise TypeError( "in_types must be a dictionary of DataType") - c_arity = CArity(num_args, func_spec.varargs) + c_arity = CArity(num_args, is_varargs) if "summary" not in function_doc: raise ValueError("Function doc must contain a summary") @@ -2563,7 +2566,9 @@ def register_scalar_function(func, function_name, function_doc, in_types, if func_registry is None: c_func_registry = NULL else: - c_func_registry = pyarrow_unwrap_function_registry(func_registry) + print(f"func_registry: {func_registry}") + c_func_registry = (func_registry).registry + print(f"c_func_registry: {c_func_registry != NULL}") check_status(RegisterScalarFunction(c_function, &_scalar_udf_callback, diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index e458f70939e..9c812bfb3b1 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -28,6 +28,7 @@ from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * +from pyarrow._compute 
cimport FunctionRegistry from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan @@ -115,14 +116,25 @@ def register_udf_declarations(plan, extid_registry, func_registry, udf_decls=Non udf_arg_names = list(inspect.signature(udf_func).parameters.keys()) udf_arg_types = udf_decl["input_types"] register_function(extid_registry, None, udf_name, udf_name) + def udf(ctx, *args): + try: + r = udf_func(*args) + with open("bblah", "w") as f: + f.write(str((ctx,args,r))) + return r + except: + import sys + with open("bblah", "w") as f: + f.write(str((ctx,args,sys.exc_info()))) + raise pc.register_scalar_function( - udf_func, + udf, udf_name, {"summary": udf_decl["summary"], "description": udf_decl["description"]}, # range start from 1 to skip over udf scalar context argument {udf_arg_names[i]: udf_arg_types[i][0] - for i in range(1 ,len(udf_arg_types))}, + for i in range(0 ,len(udf_arg_types))}, udf_decl["output_type"][0], func_registry, ) @@ -147,7 +159,12 @@ def _run_query(plan, extid_registry, func_registry, output_type): c_buf_plan = pyarrow_unwrap_buffer(plan) c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + print("substrait _1") c_func_registry = pyarrow_unwrap_function_registry(func_registry) + if c_func_registry == NULL: + c_func_registry = (func_registry).registry + print(f"c_func_registry: {c_func_registry != NULL}") + print("substrait _2") with nogil: c_res_decls = DeserializePlans( deref(c_buf_plan), c_extid_registry.get()) @@ -179,7 +196,12 @@ def run_query(plan, extid_registry, func_registry): c_buf_plan = pyarrow_unwrap_buffer(plan) c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + print("substrait 1") c_func_registry = pyarrow_unwrap_function_registry(func_registry) + if c_func_registry == NULL: + c_func_registry = (func_registry).registry + print(f"c_func_registry: {c_func_registry != NULL}") + print("substrait 2") with nogil: c_res_reader = ExecuteSerializedPlan( deref(c_buf_plan), 
c_extid_registry.get(), c_func_registry diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 607cc475553..9de8d5d8b03 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -419,6 +419,7 @@ cdef api object pyarrow_wrap_batch( cdef api bint pyarrow_is_function_registry(object registry): + print(f"is_reg: {registry},{isinstance(registry, BaseFunctionRegistry)},{BaseFunctionRegistry}") return isinstance(registry, BaseFunctionRegistry) @@ -430,8 +431,10 @@ cdef api CFunctionRegistry* pyarrow_unwrap_function_registry(object registry): cdef BaseFunctionRegistry reg if pyarrow_is_function_registry(registry): reg = (registry) + print(f"reg.registry: {reg.registry != NULL}, registry: {registry}") return reg.registry + print(f"reg: False , registry: {registry}") return NULL @@ -447,6 +450,7 @@ cdef api shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(o cdef api object pyarrow_wrap_function_registry(CFunctionRegistry* cregistry): cdef BaseFunctionRegistry registry = BaseFunctionRegistry.__new__(BaseFunctionRegistry) registry.registry = cregistry + print(f"registry.registry: {cregistry != NULL}") return registry diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index f358427324f..d1427ce6ffb 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -673,3 +673,31 @@ def test_elementwise_scalar_udf_in_substrait_query(tmpdir): assert res_tb.schema == pa.schema( [("key", pa.string()), ("value", pa.float64()), ("twice", pa.float64())]) assert res_tb.drop(["twice"]) == table + + +def _test_query(path): + with open(path) as f: + querystr = f.read() + query = tobytes(querystr) + + plan = substrait._parse_json_plan(query) + + extid_registry = substrait.make_extension_id_registry() + func_registry = substrait.make_function_registry() + substrait.register_udf_declarations(plan, extid_registry, func_registry) + + reader = substrait.run_query(plan, 
extid_registry, func_registry) + res_tb = reader.read_all() + + +def test_modelstate_udf_add(): + _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_udf_add") + _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_udf_add") + #import timeit + #timeit.timeit(lambda: _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_udf_add")) + + +def test_modelstate_reg_add(): + _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_reg_add") + #import timeit + #timeit.timeit(lambda: _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_reg_add")) From a15e0ca77e6ee2232b6d84c56ae361f657cc0535 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 10 Jun 2022 05:07:46 -0400 Subject: [PATCH 16/19] simple UDF benchmark --- cpp/src/arrow/compute/exec/hash_join_node.cc | 15 + python/pyarrow/tests/test_udf.py | 533 ++++++++++++++++++- 2 files changed, 536 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 7f33df15b85..8ea6883b558 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -455,6 +455,20 @@ Status HashJoinSchema::CollectFilterColumns(std::vector& left_filter, return Status::OK(); } +Status ValidateHashJoinNodeOptions(const HashJoinNodeOptions& join_options) { + if (join_options.key_cmp.empty() || join_options.left_keys.empty() || + join_options.right_keys.empty()) { + return Status::Invalid("key_cmp and keys cannot be empty"); + } + + if ((join_options.key_cmp.size() != join_options.left_keys.size()) || + (join_options.key_cmp.size() != join_options.right_keys.size())) { + return Status::Invalid("key_cmp and keys must have the same size"); + } + + return Status::OK(); +} + class HashJoinNode : public ExecNode { public: HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options, @@ -482,6 +496,7 @@ class HashJoinNode : public 
ExecNode { ::arrow::internal::make_unique(); const auto& join_options = checked_cast(options); + RETURN_NOT_OK(ValidateHashJoinNodeOptions(join_options)); const auto& left_schema = *(inputs[0]->output_schema()); const auto& right_schema = *(inputs[1]->output_schema()); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index d1427ce6ffb..482386d9029 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -17,6 +17,7 @@ import os +import json import pytest import pyarrow as pa @@ -675,11 +676,7 @@ def test_elementwise_scalar_udf_in_substrait_query(tmpdir): assert res_tb.drop(["twice"]) == table -def _test_query(path): - with open(path) as f: - querystr = f.read() - query = tobytes(querystr) - +def _test_query(query): plan = substrait._parse_json_plan(query) extid_registry = substrait.make_extension_id_registry() @@ -690,14 +687,526 @@ def _test_query(path): res_tb = reader.read_all() +def _test_query_string(querystr): + query = tobytes(querystr) + return _test_query(query) + + +def _test_query_path(path): + with open(path) as f: + querystr = f.read() + return _test_query_string(querystr) + + def test_modelstate_udf_add(): - _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_udf_add") - _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_udf_add") - #import timeit - #timeit.timeit(lambda: _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_udf_add")) + code = ( + "gAWVGgMAAAAAAACMF2Nsb3VkcGlja2xlLmNsb3VkcGlja2xllIwNX2J1aWx0aW5fdHlwZZSTlIwKTGFt" + + "YmRhVHlwZZSFlFKUKGgCjAhDb2RlVHlwZZSFlFKUKEsBSwBLAEsCSwRLQ0MYZAFkAmwAbQF9AQEAfAGg" + + "AnwAZAOhAlMAlCiMJENvbXB1dGUgdHdpY2UgdGhlIHZhbHVlIG9mIHRoZSBpbnB1dJRLAE5LAnSUjA9w" + + "eWFycm93LmNvbXB1dGWUjAdjb21wdXRllIwIbXVsdGlwbHmUh5SMAXaUjAJwY5SGlIxgL21udC91c2Vy" + + "MS90c2NvbnRyYWN0L2dpdGh1Yi9ydHBzdy9pYmlzLXN1YnN0cmFpdC9pYmlzX3N1YnN0cmFpdC90ZXN0" + + 
"cy9jb21waWxlci90ZXN0X2NvbXBpbGVyLnB5lIwFdHdpY2WUTUYBQwQAAwwBlCkpdJRSlH2UKIwLX19w" + + "YWNrYWdlX1+UjB1pYmlzX3N1YnN0cmFpdC50ZXN0cy5jb21waWxlcpSMCF9fbmFtZV9flIwraWJpc19z" + + "dWJzdHJhaXQudGVzdHMuY29tcGlsZXIudGVzdF9jb21waWxlcpSMCF9fZmlsZV9flIxgL21udC91c2Vy" + + "MS90c2NvbnRyYWN0L2dpdGh1Yi9ydHBzdy9pYmlzLXN1YnN0cmFpdC9pYmlzX3N1YnN0cmFpdC90ZXN0" + + "cy9jb21waWxlci90ZXN0X2NvbXBpbGVyLnB5lHVOTk50lFKUjBxjbG91ZHBpY2tsZS5jbG91ZHBpY2ts" + + "ZV9mYXN0lIwSX2Z1bmN0aW9uX3NldHN0YXRllJOUaCB9lH2UKGgbaBSMDF9fcXVhbG5hbWVfX5RoFIwP" + + "X19hbm5vdGF0aW9uc19flH2UjA5fX2t3ZGVmYXVsdHNfX5ROjAxfX2RlZmF1bHRzX1+UTowKX19tb2R1" + + "bGVfX5RoHIwHX19kb2NfX5RoCowLX19jbG9zdXJlX1+UTowXX2Nsb3VkcGlja2xlX3N1Ym1vZHVsZXOU" + + "XZSMC19fZ2xvYmFsc19flH2UdYaUhlIwLg==" + ) + querystr = json.dumps( + { + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "twice", + "udf": { + "code": code, + "summary": "twice", + "description": "Compute twice the value of the input", + "inputTypes": [ + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "time", + "id", + "return_prev_1d_lag", + "return_next_1d_lead", + "variance", + "volume", + "market_cap", + "factor_id", + "price" + ], + "struct": { + "types": [ + { + "timestamp": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + 
} + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "localFiles": { + "items": [ + { + "uriFile": "file:///mnt/user1/tscontract/github/rtpsw/bamboo-streaming/data/modelstate2.feather" + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + }, + { + "scalarFunction": { + "functionReference": 1, + "args": [ + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + } + ], + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ] + } + }, + "names": [ + "time", + "id", + "return_prev_1d_lag", + "return_next_1d_lead", + "variance", + "volume", + "market_cap", + "factor_id", + "price", + "twice_volume" + ] + } + } + ] + } + ) + import timeit + n = 5 + secs = timeit.timeit( + lambda: 
_test_query_string(querystr), + number=n + ) + with open("test_modelstate_udf_add.timeit", "w") as f: + f.write(f"seconds: {secs/n}\n") def test_modelstate_reg_add(): - _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_reg_add") - #import timeit - #timeit.timeit(lambda: _test_query("/mnt/user1/tscontract/github/rtpsw/ibis-substrait/blah_reg_add")) + querystr = json.dumps( + { + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "*" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "read": { + "baseSchema": { + "names": [ + "time", + "id", + "return_prev_1d_lag", + "return_next_1d_lead", + "variance", + "volume", + "market_cap", + "factor_id", + "price" + ], + "struct": { + "types": [ + { + "timestamp": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "localFiles": { + "items": [ + { + "uriFile": "file:///mnt/user1/tscontract/github/rtpsw/bamboo-streaming/data/modelstate2.feather" + } + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + 
"selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + }, + { + "scalarFunction": { + "functionReference": 1, + "args": [ + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "literal": { + "i8": 2 + } + } + ], + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + } + } + ] + } + }, + "names": [ + "time", + "id", + "return_prev_1d_lag", + "return_next_1d_lead", + "variance", + "volume", + "market_cap", + "factor_id", + "price", + "twice_volume" + ] + } + } + ] + } + ) + import timeit + n = 5 + secs = timeit.timeit( + lambda: _test_query_string(querystr), + number=n + ) + with open("test_modelstate_reg_add.timeit", "w") as f: + f.write(f"seconds: {secs/n}\n") From 85bbaf4fd688ec81e406131bc0e037d0ca7a7904 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 12 Jun 2022 15:24:07 -0400 Subject: [PATCH 17/19] improved UDF PoC benchmark --- cpp/src/arrow/engine/substrait/util.cc | 3 +- python/pyarrow/_compute.pyx | 4 - python/pyarrow/_substrait.pyx | 6 - python/pyarrow/public-api.pxi | 4 - python/pyarrow/tests/test_udf.py | 365 ++++--------------------- 5 files changed, 51 insertions(+), 331 deletions(-) diff --git 
a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index c6d45d50c5a..05d4e3478df 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -68,7 +68,7 @@ class SubstraitExecutor { compute::ExecContext exec_context) : plan_(std::move(plan)), exec_context_(exec_context) {} - ~SubstraitExecutor() { ARROW_CHECK_OK(this->Close()); } + ~SubstraitExecutor() { this->Close(); } Result> Execute() { for (const compute::Declaration& decl : declarations_) { @@ -120,6 +120,7 @@ Result> ExecuteSerializedPlan( SubstraitExecutor executor(std::move(plan), exec_context); RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry)); ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); + RETURN_NOT_OK(executor.Close()); return sink_reader; } diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 60431fa71c8..3d5bdae315e 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -471,7 +471,6 @@ cdef class FunctionRegistry(_Weakrefable): self.registry = GetFunctionRegistry() else: self.registry = pyarrow_unwrap_function_registry(registry) - print(f"self: {self} , self.registry: {self.registry != NULL}") def list_functions(self): """ @@ -508,7 +507,6 @@ def make_function_registry(): up_registry = MakeFunctionRegistry() c_registry = up_registry.get() up_registry.release() - print(f"up_registry: {c_registry != NULL}") return FunctionRegistry(pyarrow_wrap_function_registry(c_registry)) @@ -2566,9 +2564,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, if func_registry is None: c_func_registry = NULL else: - print(f"func_registry: {func_registry}") c_func_registry = (func_registry).registry - print(f"c_func_registry: {c_func_registry != NULL}") check_status(RegisterScalarFunction(c_function, &_scalar_udf_callback, diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 9c812bfb3b1..5e4f065ad47 100644 --- 
a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -159,12 +159,9 @@ def _run_query(plan, extid_registry, func_registry, output_type): c_buf_plan = pyarrow_unwrap_buffer(plan) c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) - print("substrait _1") c_func_registry = pyarrow_unwrap_function_registry(func_registry) if c_func_registry == NULL: c_func_registry = (func_registry).registry - print(f"c_func_registry: {c_func_registry != NULL}") - print("substrait _2") with nogil: c_res_decls = DeserializePlans( deref(c_buf_plan), c_extid_registry.get()) @@ -196,12 +193,9 @@ def run_query(plan, extid_registry, func_registry): c_buf_plan = pyarrow_unwrap_buffer(plan) c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) - print("substrait 1") c_func_registry = pyarrow_unwrap_function_registry(func_registry) if c_func_registry == NULL: c_func_registry = (func_registry).registry - print(f"c_func_registry: {c_func_registry != NULL}") - print("substrait 2") with nogil: c_res_reader = ExecuteSerializedPlan( deref(c_buf_plan), c_extid_registry.get(), c_func_registry diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 9de8d5d8b03..607cc475553 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -419,7 +419,6 @@ cdef api object pyarrow_wrap_batch( cdef api bint pyarrow_is_function_registry(object registry): - print(f"is_reg: {registry},{isinstance(registry, BaseFunctionRegistry)},{BaseFunctionRegistry}") return isinstance(registry, BaseFunctionRegistry) @@ -431,10 +430,8 @@ cdef api CFunctionRegistry* pyarrow_unwrap_function_registry(object registry): cdef BaseFunctionRegistry reg if pyarrow_is_function_registry(registry): reg = (registry) - print(f"reg.registry: {reg.registry != NULL}, registry: {registry}") return reg.registry - print(f"reg: False , registry: {registry}") return NULL @@ -450,7 +447,6 @@ cdef api shared_ptr[CExtensionIdRegistry] 
pyarrow_unwrap_extension_id_registry(o cdef api object pyarrow_wrap_function_registry(CFunctionRegistry* cregistry): cdef BaseFunctionRegistry registry = BaseFunctionRegistry.__new__(BaseFunctionRegistry) registry.registry = cregistry - print(f"registry.registry: {cregistry != NULL}") return registry diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 482386d9029..4176f8a17cd 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -20,6 +20,7 @@ import json import pytest +import numpy as np import pyarrow as pa from pyarrow import compute as pc from pyarrow.lib import frombytes, tobytes, DoubleArray @@ -698,7 +699,7 @@ def _test_query_path(path): return _test_query_string(querystr) -def test_modelstate_udf_add(): +def _simple_udf_add_query_string(input_path): code = ( "gAWVGgMAAAAAAACMF2Nsb3VkcGlja2xlLmNsb3VkcGlja2xllIwNX2J1aWx0aW5fdHlwZZSTlIwKTGFt" + "YmRhVHlwZZSFlFKUKGgCjAhDb2RlVHlwZZSFlFKUKEsBSwBLAEsCSwRLQ0MYZAFkAmwAbQF9AQEAfAGg" + @@ -758,63 +759,15 @@ def test_modelstate_udf_add(): "read": { "baseSchema": { "names": [ - "time", - "id", - "return_prev_1d_lag", - "return_next_1d_lead", - "variance", - "volume", - "market_cap", - "factor_id", - "price" + "v", ], "struct": { "types": [ - { - "timestamp": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, { "fp64": { "nullability": "NULLABILITY_NULLABLE" } }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - } ], "nullability": "NULLABILITY_REQUIRED" } @@ -822,101 +775,13 @@ def test_modelstate_udf_add(): "localFiles": { "items": [ { - "uriFile": 
"file:///mnt/user1/tscontract/github/rtpsw/bamboo-streaming/data/modelstate2.feather" + "uriFile": "file://" + input_path } ] } } }, "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - }, { "scalarFunction": { "functionReference": 1, @@ -924,9 +789,7 @@ def test_modelstate_udf_add(): { "selection": { "directReference": { - "structField": { - "field": 5 - } + "structField": {} }, "rootReference": {} } @@ -943,33 +806,17 @@ def test_modelstate_udf_add(): } }, "names": [ - "time", - "id", - "return_prev_1d_lag", - "return_next_1d_lead", - "variance", - "volume", - "market_cap", - "factor_id", - "price", - "twice_volume" + "twice_v" ] } } ] } ) - import timeit - n = 5 - secs = timeit.timeit( - lambda: _test_query_string(querystr), - number=n - ) - with open("test_modelstate_udf_add.timeit", "w") as f: - f.write(f"seconds: {secs/n}\n") + return querystr -def test_modelstate_reg_add(): +def _simple_reg_add_query_string(input_path): querystr = json.dumps( { "extensionUris": [ @@ -996,63 +843,15 @@ def 
test_modelstate_reg_add(): "read": { "baseSchema": { "names": [ - "time", - "id", - "return_prev_1d_lag", - "return_next_1d_lead", - "variance", - "volume", - "market_cap", - "factor_id", - "price" + "v", ], "struct": { "types": [ - { - "timestamp": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, { "fp64": { "nullability": "NULLABILITY_NULLABLE" } }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - } ], "nullability": "NULLABILITY_REQUIRED" } @@ -1060,101 +859,13 @@ def test_modelstate_reg_add(): "localFiles": { "items": [ { - "uriFile": "file:///mnt/user1/tscontract/github/rtpsw/bamboo-streaming/data/modelstate2.feather" + "uriFile": "file://" + input_path } ] } } }, "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": 
{} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - }, { "scalarFunction": { "functionReference": 1, @@ -1162,9 +873,7 @@ def test_modelstate_reg_add(): { "selection": { "directReference": { - "structField": { - "field": 5 - } + "structField": {} }, "rootReference": {} } @@ -1186,27 +895,51 @@ def test_modelstate_reg_add(): } }, "names": [ - "time", - "id", - "return_prev_1d_lag", - "return_next_1d_lead", - "variance", - "volume", - "market_cap", - "factor_id", - "price", - "twice_volume" + "twice_v" ] } } ] } ) + return querystr + + +def _make_ipc_data(data_path, batch_size, num_batches): + sink = open(data_path, "wb") + schema = pa.schema([("v", pa.float64())]) + writer = pa.ipc.new_file(sink, schema) + for i in range(num_batches): + batch = pa.record_batch([np.random.randn(batch_size)], schema=schema) + writer.write_batch(batch) + writer.close() + sink.close() + + +def _timeit_query(querystr, n): import timeit - n = 5 - secs = timeit.timeit( + total_secs = timeit.timeit( lambda: _test_query_string(querystr), number=n ) - with open("test_modelstate_reg_add.timeit", "w") as f: - f.write(f"seconds: {secs/n}\n") + return total_secs / n + + +def _simple_query(tmpdir, name): + data_path = os.path.realpath(os.path.join(tmpdir, f"{name}.feather")) + output_path = f"{name}.timeit" + querystr = _simple_udf_add_query_string(data_path) + batch_size = 1024 + n = 5 + with open(output_path, "w") as f: + for num_batches in [1, 10, 100, 1000, 10000]: + _make_ipc_data(data_path, batch_size, num_batches) + secs = _timeit_query(querystr, n) + f.write(f"batches={num_batches} size={batch_size} average-of-{n}:seconds={secs/n}\n") + + +def test_simple_udf_add(tmpdir): + _simple_query(tmpdir, "test_simple_udf_add") + +def test_simple_reg_add(tmpdir): + _simple_query(tmpdir, "test_simple_reg_add") From 394676aea22a3ea3cef6e605f0faaa6a9a422ed6 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 14 Jun 2022 05:55:18 
-0400 Subject: [PATCH 18/19] add substrait tests --- .../arrow/engine/substrait/plan_internal.cc | 2 +- .../engine/substrait/relation_internal.cc | 10 +- cpp/src/arrow/engine/substrait/serde_test.cc | 274 ++++++++++++------ 3 files changed, 185 insertions(+), 101 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index dc5ca7eb198..1498a80f3f9 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -123,7 +123,7 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { if (exclude_functions) { - break; + break; } const auto& fn = ext.extension_function(); util::string_view uri = uris[fn.extension_uri_reference()]; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 69c99ae8721..3bd6bfa135d 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -96,7 +96,7 @@ Result FromProto(const substrait::Expression& expr, const std::string& } default: { return Status::NotImplemented(std::string( - "substrait::AsOfMergeRel with non-selection for ") + what); + "substrait::Expression with non-selection for ") + what); } } return FieldRef(FieldPath({index})); @@ -379,9 +379,7 @@ Result FromProtoInternal( return compute::Declaration::Sequence({ std::move(input), - {"project", - compute::ProjectNodeOptions{std::move(expressions)} - }, + {"project", compute::ProjectNodeOptions{std::move(expressions)}}, }); } @@ -526,9 +524,7 @@ Result FromProto(const substrait::Rel& rel, } return compute::Declaration::Sequence({ std::move(input), - {"project", - compute::ProjectNodeOptions{std::move(expressions), std::move(names)} - }, + {"project", compute::ProjectNodeOptions{std::move(expressions), std::move(names)}}, }); } diff --git 
a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 4f8379d2911..bb3af56ec8e 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -663,7 +663,7 @@ TEST(Substrait, ReadRel) { } TEST(Substrait, ExtensionSetFromPlan) { - ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + std::string substrait_json = R"({ "relations": [ {"rel": { "read": { @@ -680,7 +680,7 @@ TEST(Substrait, ExtensionSetFromPlan) { "extension_uris": [ { "extension_uri_anchor": 7, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": ")" + substrait::default_extension_types_uri() + R"(" } ], "extensions": [ @@ -695,32 +695,37 @@ TEST(Substrait, ExtensionSetFromPlan) { "name": "add" }} ] - })")); - ExtensionSet ext_set; - ASSERT_OK_AND_ASSIGN( - auto sink_decls, - DeserializePlans( - *buf, [] { return std::shared_ptr{nullptr}; }, - NULLPTR, &ext_set)); - - EXPECT_OK_AND_ASSIGN(auto decoded_null_type, ext_set.DecodeType(42)); - EXPECT_EQ(decoded_null_type.id.uri, kArrowExtTypesUri); - EXPECT_EQ(decoded_null_type.id.name, "null"); - EXPECT_EQ(*decoded_null_type.type, NullType()); - - EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42)); - EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri); - EXPECT_EQ(decoded_add_func.id.name, "add"); - EXPECT_EQ(decoded_add_func.name, "add"); + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + for (auto sp_ext_id_reg : { + std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set)); + + EXPECT_OK_AND_ASSIGN(auto decoded_null_type, ext_set.DecodeType(42)); + EXPECT_EQ(decoded_null_type.id.uri, 
kArrowExtTypesUri); + EXPECT_EQ(decoded_null_type.id.name, "null"); + EXPECT_EQ(*decoded_null_type.type, NullType()); + + EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42)); + EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri); + EXPECT_EQ(decoded_add_func.id.name, "add"); + EXPECT_EQ(decoded_add_func.name, "add"); + } } TEST(Substrait, ExtensionSetFromPlanMissingFunc) { - ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + std::string substrait_json = R"({ "relations": [], "extension_uris": [ { "extension_uri_anchor": 7, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": ")" + substrait::default_extension_types_uri() + R"(" } ], "extensions": [ @@ -730,14 +735,62 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) { "name": "does_not_exist" }} ] - })")); + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + for (auto sp_ext_id_reg : { + std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_RAISES( + Invalid, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set)); + } +} - ExtensionSet ext_set; +TEST(Substrait, ExtensionSetFromPlanRegisterFunc) { + std::string substrait_json = R"({ + "relations": [], + "extension_uris": [ + { + "extension_uri_anchor": 7, + "uri": ")" + substrait::default_extension_types_uri() + R"(" + } + ], + "extensions": [ + {"extension_function": { + "extension_uri_reference": 7, + "function_anchor": 42, + "name": "new_func" + }} + ] + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + // invalid before registration + ExtensionSet ext_set_invalid(ext_id_reg); ASSERT_RAISES( Invalid, DeserializePlans( 
*buf, [] { return std::shared_ptr{nullptr}; }, - NULLPTR, &ext_set)); + ext_id_reg, &ext_set_invalid)); + substrait::RegisterFunction(*ext_id_reg, substrait::default_extension_types_uri(), + "new_func", "multiply"); + // valid after registration + ExtensionSet ext_set_valid(ext_id_reg); + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set_valid)); + EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set_valid.DecodeFunction(42)); + EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri); + EXPECT_EQ(decoded_add_func.id.name, "new_func"); + EXPECT_EQ(decoded_add_func.name, "multiply"); } Result GetSubstraitJSON() { @@ -752,7 +805,7 @@ Result GetSubstraitJSON() { "read": { "base_schema": { "struct": { - "types": [ + "types": [ {"binary": {}} ] }, @@ -778,17 +831,31 @@ Result GetSubstraitJSON() { return substrait_json; } +static void test_with_registries(std::function test) { + auto default_func_reg = compute::GetFunctionRegistry(); + auto nested_ext_id_reg = substrait::MakeExtensionIdRegistry(); + auto nested_func_reg = compute::FunctionRegistry::Make(default_func_reg); + test(NULLPTR, default_func_reg); + test(NULLPTR, nested_func_reg.get()); + test(nested_ext_id_reg.get(), default_func_reg); + test(nested_ext_id_reg.get(), nested_func_reg.get()); +} + TEST(Substrait, GetRecordBatchReader) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #else ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON()); - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); - ASSERT_OK_AND_ASSIGN(auto reader, substrait::ExecuteSerializedPlan(*buf)); - ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatchReader(reader.get())); - // Note: assuming the binary.parquet file contains fixed amount of records - // in case of a test failure, re-evalaute the content in the file - EXPECT_EQ(table->num_rows(), 12); + 
test_with_registries([&substrait_json](ExtensionIdRegistry* ext_id_reg, + compute::FunctionRegistry* func_registry) { + ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_OK_AND_ASSIGN(auto reader, substrait::ExecuteSerializedPlan(*buf)); + ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatchReader(reader.get())); + // Note: assuming the binary.parquet file contains fixed amount of records + // in case of a test failure, re-evalaute the content in the file + EXPECT_EQ(table->num_rows(), 12); + }); #endif } @@ -797,12 +864,15 @@ TEST(Substrait, InvalidPlan) { "relations": [ ] })"; - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); - ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf)); + test_with_registries([&substrait_json](ExtensionIdRegistry* ext_id_reg, + compute::FunctionRegistry* func_registry) { + ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf)); + }); } TEST(Substrait, JoinPlanBasic) { - ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + std::string substrait_json = R"({ "relations": [{ "rel": { "join": { @@ -820,13 +890,13 @@ TEST(Substrait, JoinPlanBasic) { }] } }, - "local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat1.parquet", "format": "FILE_FORMAT_PARQUET" } - ] + ] } } }, @@ -844,7 +914,7 @@ TEST(Substrait, JoinPlanBasic) { }] } }, - "local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat2.parquet", @@ -887,7 +957,7 @@ TEST(Substrait, JoinPlanBasic) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": ")" + substrait::default_extension_types_uri() + R"(" } ], "extensions": [ @@ -897,44 +967,49 @@ TEST(Substrait, JoinPlanBasic) { "name": "equal" }} ] - })")); - ExtensionSet ext_set; - ASSERT_OK_AND_ASSIGN( - auto 
sink_decls, - DeserializePlans( - *buf, [] { return std::shared_ptr{nullptr}; }, - NULLPTR, &ext_set)); + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + for (auto sp_ext_id_reg : { + std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set)); - auto join_decl = sink_decls[0].inputs[0]; + auto join_decl = sink_decls[0].inputs[0]; - const auto& join_rel = join_decl.get(); + const auto& join_rel = join_decl.get(); - const auto& join_options = - checked_cast(*join_rel->options); + const auto& join_options = + checked_cast(*join_rel->options); - EXPECT_EQ(join_rel->factory_name, "hashjoin"); - EXPECT_EQ(join_options.join_type, compute::JoinType::INNER); + EXPECT_EQ(join_rel->factory_name, "hashjoin"); + EXPECT_EQ(join_options.join_type, compute::JoinType::INNER); - const auto& left_rel = join_rel->inputs[0].get(); - const auto& right_rel = join_rel->inputs[1].get(); + const auto& left_rel = join_rel->inputs[0].get(); + const auto& right_rel = join_rel->inputs[1].get(); - const auto& l_options = - checked_cast(*left_rel->options); - const auto& r_options = - checked_cast(*right_rel->options); + const auto& l_options = + checked_cast(*left_rel->options); + const auto& r_options = + checked_cast(*right_rel->options); - AssertSchemaEqual( - l_options.dataset->schema(), - schema({field("A", int32()), field("B", int32()), field("C", int32())})); - AssertSchemaEqual( - r_options.dataset->schema(), - schema({field("X", int32()), field("Y", int32()), field("A", int32())})); + AssertSchemaEqual( + l_options.dataset->schema(), + schema({field("A", int32()), field("B", int32()), field("C", int32())})); + AssertSchemaEqual( + r_options.dataset->schema(), + schema({field("X", int32()), field("Y", int32()), 
field("A", int32())})); - EXPECT_EQ(join_options.key_cmp[0], compute::JoinKeyCmp::EQ); + EXPECT_EQ(join_options.key_cmp[0], compute::JoinKeyCmp::EQ); + } } TEST(Substrait, JoinPlanInvalidKeyCmp) { - ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + std::string substrait_json = R"({ "relations": [{ "rel": { "join": { @@ -952,13 +1027,13 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { }] } }, - "local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat1.parquet", "format": "FILE_FORMAT_PARQUET" } - ] + ] } } }, @@ -976,7 +1051,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { }] } }, - "local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat2.parquet", @@ -1019,7 +1094,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": ")" + substrait::default_extension_types_uri() + R"(" } ], "extensions": [ @@ -1029,13 +1104,18 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { "name": "add" }} ] - })")); - ExtensionSet ext_set; - ASSERT_RAISES( - Invalid, - DeserializePlans( - *buf, [] { return std::shared_ptr{nullptr}; }, - NULLPTR, &ext_set)); + })"; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + for (auto sp_ext_id_reg : { + std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_RAISES( + Invalid, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set)); + } } TEST(Substrait, JoinPlanInvalidExpression) { @@ -1057,13 +1137,13 @@ TEST(Substrait, JoinPlanInvalidExpression) { }] } }, - "local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat1.parquet", "format": "FILE_FORMAT_PARQUET" } - ] + ] } } }, @@ -1081,7 +1161,7 @@ TEST(Substrait, JoinPlanInvalidExpression) { }] } }, - 
"local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat2.parquet", @@ -1097,12 +1177,16 @@ TEST(Substrait, JoinPlanInvalidExpression) { } }] })")); - ExtensionSet ext_set; - ASSERT_RAISES( - Invalid, - DeserializePlans( - *buf, [] { return std::shared_ptr{nullptr}; }, - NULLPTR, &ext_set)); + for (auto sp_ext_id_reg : { + std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_RAISES( + Invalid, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set)); + } } TEST(Substrait, JoinPlanInvalidKeys) { @@ -1124,13 +1208,13 @@ TEST(Substrait, JoinPlanInvalidKeys) { }] } }, - "local_files": { + "local_files": { "items": [ { "uri_file": "file:///tmp/dat1.parquet", "format": "FILE_FORMAT_PARQUET" } - ] + ] } } }, @@ -1165,12 +1249,16 @@ TEST(Substrait, JoinPlanInvalidKeys) { } }] })")); - ExtensionSet ext_set; - ASSERT_RAISES( - Invalid, - DeserializePlans( - *buf, [] { return std::shared_ptr{nullptr}; }, - NULLPTR, &ext_set)); + for (auto sp_ext_id_reg : { + std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + ASSERT_RAISES( + Invalid, + DeserializePlans( + *buf, [] { return std::shared_ptr{nullptr}; }, + ext_id_reg, &ext_set)); + } } } // namespace engine From 1b1fdde06a8220b5050c2600d0d73e4999265a20 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 3 Jul 2022 05:07:00 -0400 Subject: [PATCH 19/19] lint --- python/pyarrow/_substrait.pyx | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 5e4f065ad47..ce0d5704f4a 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -117,16 +117,8 @@ def register_udf_declarations(plan, extid_registry, func_registry, udf_decls=Non udf_arg_types = 
udf_decl["input_types"] register_function(extid_registry, None, udf_name, udf_name) def udf(ctx, *args): - try: - r = udf_func(*args) - with open("bblah", "w") as f: - f.write(str((ctx,args,r))) - return r - except: - import sys - with open("bblah", "w") as f: - f.write(str((ctx,args,sys.exc_info()))) - raise + return udf_func(*args) + pc.register_scalar_function( udf, udf_name,