diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 4fc312a6e30..be109d39f13 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -413,6 +413,7 @@ if(ARROW_COMPUTE) compute/kernel.cc compute/light_array.cc compute/registry.cc + compute/registry_util.cc compute/kernels/aggregate_basic.cc compute/kernels/aggregate_mode.cc compute/kernels/aggregate_quantile.cc diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index a86b6c63d36..f4314b9e409 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -221,6 +221,18 @@ class ARROW_EXPORT SinkNodeConsumer { virtual Future<> Finish() = 0; }; +class ARROW_EXPORT NullSinkNodeConsumer : public SinkNodeConsumer { + public: + Status Init(const std::shared_ptr&, BackpressureControl*) override { + return Status::OK(); + } + Status Consume(ExecBatch exec_batch) override { return Status::OK(); } + Future<> Finish() override { return Status::OK(); } + static std::shared_ptr Make() { + return std::make_shared(); + } +}; + /// \brief Add a sink node which consumes data within the exec plan run class ARROW_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions { public: @@ -438,7 +450,5 @@ class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions { std::shared_ptr* output_table; }; -/// @} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/registry_util.cc b/cpp/src/arrow/compute/registry_util.cc new file mode 100644 index 00000000000..ca3c729fd77 --- /dev/null +++ b/cpp/src/arrow/compute/registry_util.cc @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/registry_util.h" + +namespace arrow { +namespace compute { + +std::unique_ptr MakeFunctionRegistry() { + return FunctionRegistry::Make(GetFunctionRegistry()); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry_util.h b/cpp/src/arrow/compute/registry_util.h new file mode 100644 index 00000000000..14e9bc5381c --- /dev/null +++ b/cpp/src/arrow/compute/registry_util.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include "arrow/compute/registry.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \brief Make a nested function registry with the default one as parent +ARROW_EXPORT std::unique_ptr MakeFunctionRegistry(); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 568ab414513..bd06f9444ee 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -331,6 +331,22 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { return Status::OK(); } + Status Init(compute::ExecNode* node) { + if (node == nullptr) { + return Status::Invalid("internal error - null node"); + } + auto schema = node->inputs()[0]->output_schema(); + if (schema.get() == nullptr) { + return Status::Invalid("internal error - null schema"); + } + if (schema_.get() == nullptr) { + schema_ = schema; + } else if (schema_.get() != schema.get()) { + return Status::Invalid("internal error - inconsistent schemata"); + } + return Status::OK(); + } + Status Consume(compute::ExecBatch batch) override { ARROW_ASSIGN_OR_RAISE(std::shared_ptr record_batch, batch.ToRecordBatch(schema_)); @@ -432,9 +448,15 @@ Result MakeWriteNode(compute::ExecPlan* plan, custom_metadata, std::move(dataset_writer), write_options); ARROW_ASSIGN_OR_RAISE( - auto node, - compute::MakeExecNode("consuming_sink", plan, std::move(inputs), - compute::ConsumingSinkNodeOptions{std::move(consumer)})); + auto node, compute::MakeExecNode("consuming_sink", plan, std::move(inputs), + compute::ConsumingSinkNodeOptions{consumer})); + + // this is a workaround specific for Arrow Substrait code paths + // Arrow Substrait creates ExecNodeOptions instances within a Declaration + // at this stage, schemata have not yet been created since nodes haven't been created + // thus, the 
ConsumingSinkNodeOptions passed to consumer has a null schema + // the following call to Init fills in the schema using the node just created + ARROW_RETURN_NOT_OK(consumer->Init(node)); return node; } diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 2da037000cf..1498a80f3f9 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -91,7 +91,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) } Result GetExtensionSetFromPlan(const substrait::Plan& plan, - const ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry, + bool exclude_functions) { if (registry == NULLPTR) { registry = default_extension_id_registry(); } @@ -121,6 +122,9 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, } case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + if (exclude_functions) { + break; + } const auto& fn = ext.extension_function(); util::string_view uri = uris[fn.extension_uri_reference()]; function_ids[fn.function_anchor()] = Id{uri, fn.name()}; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index dce23cdceba..4f4f752f243 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - const ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry(), + bool exclude_functions = false); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 61e29865516..c098505981a 100644 --- 
a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -16,7 +16,6 @@ // under the License. #include "arrow/engine/substrait/relation_internal.h" - #include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" #include "arrow/dataset/file_base.h" @@ -52,6 +51,69 @@ Status CheckRelCommon(const RelMessage& rel) { return Status::OK(); } +Result FromProto(const substrait::Expression& expr, const std::string& what) { + int32_t index; + switch (expr.rex_type_case()) { + case substrait::Expression::RexTypeCase::kSelection: { + const auto& selection = expr.selection(); + switch (selection.root_type_case()) { + case substrait::Expression_FieldReference::RootTypeCase::kRootReference: { + break; + } + default: { + return Status::NotImplemented( + std::string("substrait::Expression with non-root-reference for ") + what); + } + } + switch (selection.reference_type_case()) { + case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: { + const auto& direct_reference = selection.direct_reference(); + switch (direct_reference.reference_type_case()) { + case substrait::Expression_ReferenceSegment::ReferenceTypeCase:: + kStructField: { + break; + } + default: { + return Status::NotImplemented( + std::string("substrait::Expression with non-struct-field for ") + what); + } + } + const auto& struct_field = direct_reference.struct_field(); + if (struct_field.has_child()) { + return Status::NotImplemented( + std::string("substrait::Expression with non-flat struct-field for ") + + what); + } + index = struct_field.field(); + break; + } + default: { + return Status::NotImplemented( + std::string("substrait::Expression with non-direct reference for ") + what); + } + } + break; + } + default: { + return Status::NotImplemented( + std::string("substrait::Expression with non-selection for ") + what); + } + } + return FieldRef(FieldPath({index})); +} + +Result> FromProto( + const 
google::protobuf::RepeatedPtrField& exprs, + const std::string& what) { + std::vector fields; + int size = exprs.size(); + for (int i = 0; i < size; i++) { + ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(exprs[i], what)); + fields.push_back(field); + } + return fields; +} + Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set) { static bool dataset_init = false; @@ -109,6 +171,8 @@ Result FromProto(const substrait::Rel& rel, path = item.uri_path_glob(); } + util::string_view uri_file{item.uri_file()}; + if (item.format() == substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) { format = std::make_shared(); diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index dda41c282a1..bb0108e5b05 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -101,13 +101,12 @@ DeclarationFactory MakeWriteDeclarationFactory( return [&write_options_factory]( compute::Declaration input, std::vector names) -> Result { - std::shared_ptr options = write_options_factory(); + std::shared_ptr options = write_options_factory(); if (options == NULLPTR) { return Status::Invalid("write options factory is exhausted"); } compute::Declaration projected = ProjectByNamesDeclaration(input, names); - return compute::Declaration::Sequence( - {std::move(projected), {"write", std::move(*options)}}); + return compute::Declaration::Sequence({std::move(projected), {"write", options}}); }; } @@ -204,6 +203,48 @@ Result DeserializePlan( return MakeSingleDeclarationPlan(declarations); } +Result> DeserializePlanUdfs( + const Buffer& buf, const ExtensionIdRegistry* registry) { + ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); + + ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry, true)); + + std::vector decls; + /* + for (const auto& ext : plan.extensions()) { + switch (ext.mapping_type_case()) { + case 
substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + const auto& fn = ext.extension_function(); + if (fn.has_udf()) { + const auto& udf = fn.udf(); + const auto& in_types = udf.input_types(); + int size = in_types.size(); + std::vector, bool>> input_types; + for (int i=0; i> DeserializeSchema(const Buffer& buf, const ExtensionSet& ext_set) { ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer(buf)); diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 545e7449fb2..3125276a15a 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -115,6 +115,22 @@ ARROW_ENGINE_EXPORT Result DeserializePlan( const Buffer& buf, const std::shared_ptr& write_options, const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR); +/// Factory function type for generating the write options of a node consuming the batches +/// produced by each toplevel Substrait relation when deserializing a Substrait Plan. 
+using WriteOptionsFactory = std::function()>; + +struct ARROW_ENGINE_EXPORT UdfDeclaration { + std::string name; + std::string code; + std::string summary; + std::string description; + std::vector, bool>> input_types; + std::pair, bool> output_type; +}; + +ARROW_ENGINE_EXPORT Result> DeserializePlanUdfs( + const Buffer& buf, const ExtensionIdRegistry* registry); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 27b61f0b343..38ddb181b32 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -127,6 +127,13 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } +Result> DeserializePlans( + const Buffer& buffer, const ExtensionIdRegistry* registry) { + return engine::DeserializePlans( + buffer, []() { return std::make_shared(); }, + registry); +} + std::shared_ptr MakeExtensionIdRegistry() { return nested_extension_id_registry(default_extension_id_registry()); } diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 134d633bb33..ee060915e2d 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -39,6 +39,9 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); +ARROW_ENGINE_EXPORT Result> DeserializePlans( + const Buffer& buf, const ExtensionIdRegistry* registry); + /// \brief Make a nested registry with the default registry as parent. /// See arrow::engine::nested_extension_id_registry for details. 
ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); diff --git a/cpp/src/arrow/python/pyarrow.h b/cpp/src/arrow/python/pyarrow.h index 4c365081d70..c52ee7f2ebc 100644 --- a/cpp/src/arrow/python/pyarrow.h +++ b/cpp/src/arrow/python/pyarrow.h @@ -40,6 +40,12 @@ class Status; class Table; class Tensor; +namespace engine { + +class ExtensionIdRegistry; + +} // namespace engine + namespace py { // Returns 0 on success, -1 on error. @@ -71,6 +77,8 @@ DECLARE_WRAP_FUNCTIONS(tensor, Tensor) DECLARE_WRAP_FUNCTIONS(batch, RecordBatch) DECLARE_WRAP_FUNCTIONS(table, Table) +DECLARE_WRAP_FUNCTIONS(extension_id_registry, engine::ExtensionIdRegistry) + #undef DECLARE_WRAP_FUNCTIONS namespace internal { diff --git a/cpp/src/arrow/python/udf.cc b/cpp/src/arrow/python/udf.cc index 227629eb24e..6904af89656 100644 --- a/cpp/src/arrow/python/udf.cc +++ b/cpp/src/arrow/python/udf.cc @@ -104,7 +104,8 @@ struct PythonUdf { } // namespace Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options) { + const ScalarUdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(user_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -124,7 +125,9 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel))); - auto registry = compute::GetFunctionRegistry(); + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func))); return Status::OK(); } diff --git a/cpp/src/arrow/python/udf.h b/cpp/src/arrow/python/udf.h index 4ab3e7cc72b..138f9ee4908 100644 --- a/cpp/src/arrow/python/udf.h +++ b/cpp/src/arrow/python/udf.h @@ -50,9 +50,9 @@ using ScalarUdfWrapperCallback = std::function; /// 
\brief register a Scalar user-defined-function from Python -Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, - ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options); +Status ARROW_PYTHON_EXPORT RegisterScalarFunction( + PyObject* user_function, ScalarUdfWrapperCallback wrapper, + const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); } // namespace py diff --git a/python/pyarrow/__init__.pxd b/python/pyarrow/__init__.pxd index 8cc54b4c6bf..2b3b2ed1922 100644 --- a/python/pyarrow/__init__.pxd +++ b/python/pyarrow/__init__.pxd @@ -20,7 +20,7 @@ from pyarrow.includes.libarrow cimport (CArray, CBuffer, CDataType, CField, CRecordBatch, CSchema, CTable, CTensor, CSparseCOOTensor, CSparseCSRMatrix, CSparseCSCMatrix, - CSparseCSFTensor) + CSparseCSFTensor, CExtensionIdRegistry) cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": cdef int import_pyarrow() except -1 @@ -40,3 +40,5 @@ cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) cdef object wrap_table(const shared_ptr[CTable]& ctable) cdef object wrap_batch(const shared_ptr[CRecordBatch]& cbatch) + cdef object pyarrow_wrap_extension_id_registry( + shared_ptr[CExtensionIdRegistry]& cregistry) diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 8b09cbd445e..c3266c0cf9f 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -27,6 +27,9 @@ cdef class ScalarUdfContext(_Weakrefable): cdef void init(self, const CScalarUdfContext& c_context) +cdef class FunctionRegistry(_Weakrefable): + cdef CFunctionRegistry* registry + cdef class FunctionOptions(_Weakrefable): cdef: shared_ptr[CFunctionOptions] wrapped diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index b9594d90e85..697131d3f27 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -477,10 +477,11 @@ cdef _pack_compute_args(object 
values, vector[CDatum]* out): cdef class FunctionRegistry(_Weakrefable): - cdef CFunctionRegistry* registry - - def __init__(self): - self.registry = GetFunctionRegistry() + def __init__(self, registry=None): + if registry is None: + self.registry = GetFunctionRegistry() + else: + self.registry = pyarrow_unwrap_function_registry(registry) def list_functions(self): """ @@ -513,6 +514,13 @@ def function_registry(): return _global_func_registry +def make_function_registry(): + up_registry = MakeFunctionRegistry() + c_registry = up_registry.get() + up_registry.release() + return FunctionRegistry(pyarrow_wrap_function_registry(c_registry)) + + def get_function(name): """ Get a function by name. @@ -2515,7 +2523,7 @@ def _get_scalar_udf_context(memory_pool, batch_length): def register_scalar_function(func, function_name, function_doc, in_types, - out_type): + out_type, func_registry=None): """ Register a user-defined scalar function. @@ -2556,6 +2564,8 @@ def register_scalar_function(func, function_name, function_doc, in_types, arity. out_type : DataType Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. 
Examples -------- @@ -2593,6 +2603,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, PyObject* c_function shared_ptr[CDataType] c_out_type CScalarUdfOptions c_options + CFunctionRegistry* c_func_registry if callable(func): c_function = func @@ -2601,7 +2612,11 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_func_name = tobytes(function_name) - func_spec = inspect.getfullargspec(func) + try: + func_spec = inspect.getfullargspec(func) + is_varargs = func_spec.varargs is not None + except TypeError:  # unsupported callable: no introspectable signature, treat as varargs + is_varargs = True num_args = -1 if isinstance(in_types, dict): for in_type in in_types.values(): @@ -2613,7 +2628,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, raise TypeError( "in_types must be a dictionary of DataType") - c_arity = CArity(num_args, func_spec.varargs) + c_arity = CArity(num_args, is_varargs) if "summary" not in function_doc: raise ValueError("Function doc must contain a summary") @@ -2634,5 +2649,11 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_options.input_types = c_in_types c_options.output_type = c_out_type + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + check_status(RegisterScalarFunction(c_function, - &_scalar_udf_callback, c_options)) + &_scalar_udf_callback, + c_options, c_func_registry)) diff --git a/python/pyarrow/_exec_plan.pxd b/python/pyarrow/_exec_plan.pxd new file mode 100644 index 00000000000..4d7529eba64 --- /dev/null +++ b/python/pyarrow/_exec_plan.pxd @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + +cdef is_supported_execplan_output_type(output_type) + +cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=*, CFunctionRegistry* c_func_registry=*) diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx index 89e474f4390..11ddb2f4a14 100644 --- a/python/pyarrow/_exec_plan.pyx +++ b/python/pyarrow/_exec_plan.pyx @@ -36,7 +36,10 @@ from pyarrow._dataset import InMemoryDataset Initialize() # Initialise support for Datasets in ExecPlan -cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True): +cdef is_supported_execplan_output_type(output_type): + return output_type in [Table, InMemoryDataset] + +cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True, CFunctionRegistry* c_func_registry=NULL): """ Internal Function to create an ExecPlan and run it. 
@@ -75,13 +78,16 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads vector[CDeclaration.Input] no_c_inputs CStatus c_plan_status + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + if use_threads: c_executor = GetCpuThreadPool() else: c_executor = NULL c_exec_context = make_shared[CExecContext]( - c_default_memory_pool(), c_executor) + c_default_memory_pool(), c_executor, c_func_registry) c_exec_plan = GetResultValue(CExecPlan.Make(c_exec_context.get())) plan_iter = plan.begin() @@ -214,6 +220,9 @@ def _perform_join(join_type, left_operand not None, left_keys, vector[c_string] c_projected_col_names CJoinType c_join_type + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + # Prepare left and right tables Keys to send them to the C++ function left_keys_order = {} if isinstance(left_keys, str): @@ -382,6 +391,9 @@ def _filter_table(table, expression, output_type=Table): vector[CDeclaration] c_decl_plan Expression expr = expression + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + c_decl_plan.push_back( CDeclaration(tobytes("filter"), CFilterNodeOptions( expr.unwrap(), True @@ -398,4 +410,4 @@ def _filter_table(table, expression, output_type=Table): # "__fragment_index", "__batch_index", "__last_in_fragment", "__filename" return InMemoryDataset(r.select(table.schema.names)) else: - raise TypeError("Unsupported output type") + raise TypeError(f"Unsupported output type {output_type}") diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 7f079fb717b..ce0d5704f4a 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -15,16 +15,153 @@ # specific language governing permissions and limitations # under the License. 
+import base64 +import cloudpickle +import inspect + # cython: language_level = 3 -from cython.operator cimport dereference as deref +from cython.operator cimport dereference as deref, preincrement as inc +from pyarrow import compute as pc from pyarrow import Buffer +from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * +from pyarrow._compute cimport FunctionRegistry + + +from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan +from pyarrow._compute import make_function_registry + + +def make_extension_id_registry(): + cdef: + shared_ptr[CExtensionIdRegistry] c_extid_registry + ExtensionIdRegistry registry + + with nogil: + c_extid_registry = MakeExtensionIdRegistry() + + return pyarrow_wrap_extension_id_registry(c_extid_registry) + + +def _get_udf_code(func): + return frombytes(base64.b64encode(cloudpickle.dumps(func))) + + +def get_udf_declarations(plan, extid_registry): + cdef: + shared_ptr[CBuffer] c_buf_plan + shared_ptr[CExtensionIdRegistry] c_extid_registry + vector[CUdfDeclaration] c_decls + vector[CUdfDeclaration].iterator c_decls_iter + vector[pair[shared_ptr[CDataType], c_bool]].iterator c_in_types_iter + + c_buf_plan = pyarrow_unwrap_buffer(plan) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + with nogil: + c_res_decls = DeserializePlanUdfs( + deref(c_buf_plan), c_extid_registry.get()) + c_decls = GetResultValue(c_res_decls) + + decls = [] + c_decls_iter = c_decls.begin() + while c_decls_iter != c_decls.end(): + input_types = [] + c_in_types_iter = deref(c_decls_iter).input_types.begin() + while c_in_types_iter != deref(c_decls_iter).input_types.end(): + input_types.append((pyarrow_wrap_data_type(deref(c_in_types_iter).first), + deref(c_in_types_iter).second)) + inc(c_in_types_iter) + decls.append({ + "name": frombytes(deref(c_decls_iter).name), + "code": frombytes(deref(c_decls_iter).code), + 
"summary": frombytes(deref(c_decls_iter).summary), + "description": frombytes(deref(c_decls_iter).description), + "input_types": input_types, + "output_type": (pyarrow_wrap_data_type(deref(c_decls_iter).output_type.first), + deref(c_decls_iter).output_type.second), + }) + inc(c_decls_iter) + return decls + + +def register_function(extid_registry, id_uri, id_name, arrow_function_name): + cdef: + c_string c_id_uri, c_id_name, c_arrow_function_name + shared_ptr[CExtensionIdRegistry] c_extid_registry + CStatus c_status + + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + c_id_uri = id_uri or default_extension_types_uri() + c_id_name = tobytes(id_name) + c_arrow_function_name = tobytes(arrow_function_name) + + with nogil: + c_status = RegisterFunction( + deref(c_extid_registry), c_id_uri, c_id_name, c_arrow_function_name + ) + + check_status(c_status) + +def register_udf_declarations(plan, extid_registry, func_registry, udf_decls=None): + if udf_decls is None: + udf_decls = get_udf_declarations(plan, extid_registry) + for udf_decl in udf_decls: + udf_name = udf_decl["name"] + udf_func = cloudpickle.loads(  # SECURITY: unpickling executes arbitrary code; only register UDFs from trusted plans + base64.b64decode(tobytes(udf_decl["code"]))) + udf_arg_names = list(inspect.signature(udf_func).parameters.keys()) + udf_arg_types = udf_decl["input_types"] + register_function(extid_registry, None, udf_name, udf_name) + def udf(ctx, *args, _udf_func=udf_func):  # default binds this iteration's function (closures bind late) + return _udf_func(*args) -def run_query(plan): + pc.register_scalar_function( + udf, + udf_name, + {"summary": udf_decl["summary"], + "description": udf_decl["description"]}, + # udf_func takes no scalar context argument, so names and types pair up from index 0 + {udf_arg_names[i]: udf_arg_types[i][0] + for i in range(0 ,len(udf_arg_types))}, + udf_decl["output_type"][0], + func_registry, + ) + + +def run_query_as(plan, extid_registry, func_registry, output_type=RecordBatchReader): + if output_type == RecordBatchReader: + return run_query(plan, extid_registry, func_registry) + return _run_query(plan, extid_registry, 
func_registry, output_type) + + +def _run_query(plan, extid_registry, func_registry, output_type): + cdef: + shared_ptr[CBuffer] c_buf_plan + shared_ptr[CExtensionIdRegistry] c_extid_registry + CFunctionRegistry* c_func_registry + CResult[vector[CDeclaration]] c_res_decls + vector[CDeclaration] c_decls + + if not is_supported_execplan_output_type(output_type): + raise TypeError(f"Unsupported output type {output_type}") + + c_buf_plan = pyarrow_unwrap_buffer(plan) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + c_func_registry = pyarrow_unwrap_function_registry(func_registry) + if c_func_registry == NULL: + c_func_registry = (func_registry).registry + with nogil: + c_res_decls = DeserializePlans( + deref(c_buf_plan), c_extid_registry.get()) + c_decls = GetResultValue(c_res_decls) + return execplan([], output_type, c_decls, True, c_func_registry) + + +def run_query(plan, extid_registry, func_registry): """ Execute a Substrait plan and read the results as a RecordBatchReader. @@ -32,18 +169,29 @@ def run_query(plan): ---------- plan : Buffer The serialized Substrait plan to execute. + extid_registry : ExtensionIdRegistry + The extension-id-registry to execute with. + func_registry : FunctionRegistry + The function registry to execute with. 
""" cdef: + shared_ptr[CBuffer] c_buf_plan + shared_ptr[CExtensionIdRegistry] c_extid_registry + CFunctionRegistry* c_func_registry CResult[shared_ptr[CRecordBatchReader]] c_res_reader shared_ptr[CRecordBatchReader] c_reader RecordBatchReader reader - c_string c_str_plan - shared_ptr[CBuffer] c_buf_plan c_buf_plan = pyarrow_unwrap_buffer(plan) + c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry) + c_func_registry = pyarrow_unwrap_function_registry(func_registry) + if c_func_registry == NULL: + c_func_registry = (func_registry).registry with nogil: - c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan)) + c_res_reader = ExecuteSerializedPlan( + deref(c_buf_plan), c_extid_registry.get(), c_func_registry + ) c_reader = GetResultValue(c_res_reader) diff --git a/python/pyarrow/compute.pxi b/python/pyarrow/compute.pxi new file mode 100644 index 00000000000..f2684ba4211 --- /dev/null +++ b/python/pyarrow/compute.pxi @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# separating out this base class is easier than unifying it into +# FunctionRegistry, which lives outside libarrow +cdef class BaseFunctionRegistry(_Weakrefable): + cdef CFunctionRegistry* registry + +cdef class ExtensionIdRegistry(_Weakrefable): + def __cinit__(self): + self.registry = NULL + + def __init__(self): + raise TypeError("Do not call ExtensionIdRegistry's constructor directly, use " + "the `MakeExtensionIdRegistry` function instead.") + + cdef void init(self, shared_ptr[CExtensionIdRegistry]& registry): + self.sp_registry = registry + self.registry = registry.get() diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9e43eb4eb9c..8ec211b88e7 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2731,4 +2731,13 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": shared_ptr[CDataType] output_type CStatus RegisterScalarFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + CFunctionRegistry* registry) + +cdef extern from "arrow/engine/substrait/extension_set.h" namespace "arrow::engine" nogil: + + cdef cppclass CExtensionIdRegistry" arrow::engine::ExtensionIdRegistry" + +cdef extern from "arrow/compute/registry_util.h" namespace "arrow::compute" nogil: + + unique_ptr[CFunctionRegistry] MakeFunctionRegistry() diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 2e1a17b06bd..30d772b8b1a 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -21,6 +21,26 @@ from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * +cdef extern from "arrow/engine/substrait/extension_set.h" namespace "arrow::engine" nogil: + cppclass CExtensionIdRegistry "arrow::engine::ExtensionIdRegistry" + +cdef extern from 
"arrow/engine/substrait/serde.h" namespace "arrow::engine" nogil: + cppclass CUdfDeclaration "arrow::engine::UdfDeclaration": + c_string name + c_string code + c_string summary + c_string description + vector[pair[shared_ptr[CDataType], c_bool]] input_types + pair[shared_ptr[CDataType], c_bool] output_type + + CResult[vector[CUdfDeclaration]] DeserializePlanUdfs(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) + cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::substrait" nogil: - CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer) + shared_ptr[CExtensionIdRegistry] MakeExtensionIdRegistry() + CStatus RegisterFunction(CExtensionIdRegistry& registry, const c_string& id_uri, const c_string& id_name, const c_string& arrow_function_name) + + CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer, const CExtensionIdRegistry* extid_registry, CFunctionRegistry* func_registry) CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) + CResult[vector[CDeclaration]] DeserializePlans(const CBuffer& substrait_buffer, const CExtensionIdRegistry* registry) + + const c_string& default_extension_types_uri() diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 953b0e7b518..5b6a5416958 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -24,6 +24,7 @@ from libcpp.memory cimport dynamic_pointer_cast from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_python cimport * +from pyarrow.includes.libarrow_substrait cimport * cdef extern from "Python.h": @@ -446,6 +447,14 @@ cdef class RecordBatch(_PandasConvertible): cdef void init(self, const shared_ptr[CRecordBatch]& table) +cdef class ExtensionIdRegistry(_Weakrefable): + cdef: + shared_ptr[CExtensionIdRegistry] sp_registry + CExtensionIdRegistry* registry + + cdef void init(self, 
shared_ptr[CExtensionIdRegistry]& registry) + + cdef class Buffer(_Weakrefable): cdef: shared_ptr[CBuffer] buffer @@ -585,6 +594,9 @@ cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor) cdef public object pyarrow_wrap_batch(const shared_ptr[CRecordBatch]& cbatch) cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable) +cdef public object pyarrow_wrap_function_registry(CFunctionRegistry* cregistry) +cdef public object pyarrow_wrap_extension_id_registry(shared_ptr[CExtensionIdRegistry]& cregistry) + # Unwrapping Python -> C++ cdef public shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer) @@ -611,3 +623,6 @@ cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor) cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch) cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table) + +cdef public CFunctionRegistry* pyarrow_unwrap_function_registry(object registry) +cdef public shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(object registry) diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index a665ea59c6e..653ce1064cd 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -169,6 +169,9 @@ include "builder.pxi" # Column, Table, Record Batch include "table.pxi" +# Compute registries +include "compute.pxi" + # Tensors include "tensor.pxi" diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index c427fb9f5db..607cc475553 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -416,3 +416,42 @@ cdef api object pyarrow_wrap_batch( cdef RecordBatch batch = RecordBatch.__new__(RecordBatch) batch.init(cbatch) return batch + + +cdef api bint pyarrow_is_function_registry(object registry): + return isinstance(registry, BaseFunctionRegistry) + + +cdef api bint pyarrow_is_extension_id_registry(object registry): + return isinstance(registry, ExtensionIdRegistry) + + +cdef api CFunctionRegistry* 
pyarrow_unwrap_function_registry(object registry): + cdef BaseFunctionRegistry reg + if pyarrow_is_function_registry(registry): + reg = (registry) + return reg.registry + + return NULL + + +cdef api shared_ptr[CExtensionIdRegistry] pyarrow_unwrap_extension_id_registry(object registry): + cdef ExtensionIdRegistry reg + if pyarrow_is_extension_id_registry(registry): + reg = (registry) + return reg.sp_registry + + return shared_ptr[CExtensionIdRegistry]() + + +cdef api object pyarrow_wrap_function_registry(CFunctionRegistry* cregistry): + cdef BaseFunctionRegistry registry = BaseFunctionRegistry.__new__(BaseFunctionRegistry) + registry.registry = cregistry + return registry + + +cdef api object pyarrow_wrap_extension_id_registry( + shared_ptr[CExtensionIdRegistry]& cregistry): + cdef ExtensionIdRegistry registry = ExtensionIdRegistry.__new__(ExtensionIdRegistry) + registry.init(cregistry) + return registry diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index e3ff28f4eba..3584bed3cb8 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -16,5 +16,13 @@ # under the License. 
from pyarrow._substrait import ( # noqa + make_function_registry, + make_extension_id_registry, + _get_udf_code, + get_udf_declarations, + register_function, + register_udf_declarations, + run_query_as, run_query, + _parse_json_plan, ) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 8df35bbba44..b5c9d966f2a 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -22,6 +22,7 @@ import pyarrow as pa from pyarrow.lib import tobytes from pyarrow.lib import ArrowInvalid +from pyarrow.substrait import make_extension_id_registry try: import pyarrow.substrait as substrait @@ -74,7 +75,9 @@ def test_run_serialized_query(tmpdir): buf = pa._substrait._parse_json_plan(query) - reader = substrait.run_query(buf) + extid_registry = substrait.make_extension_id_registry() + func_registry = substrait.make_function_registry() + reader = substrait.run_query(buf, extid_registry, func_registry) res_tb = reader.read_all() assert table.select(["foo"]) == res_tb.select(["foo"]) @@ -88,6 +91,8 @@ def test_invalid_plan(): } """ buf = pa._substrait._parse_json_plan(tobytes(query)) + extid_registry = substrait.make_extension_id_registry() + func_registry = substrait.make_function_registry() exec_message = "Empty substrait plan is passed." with pytest.raises(ArrowInvalid, match=exec_message): - substrait.run_query(buf) + substrait.run_query(buf, extid_registry, func_registry) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e711619582d..02b0c13685f 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -18,6 +18,7 @@ import pytest +import numpy as np import pyarrow as pa from pyarrow import compute as pc