-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW 16968: [C++] Expand Python-UDF support to Arrow Substrait #13500
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
716a5b9
8327b67
d054f2a
330ae66
abee905
3f3f3ef
98d2663
5aa7ede
c0c0d08
f202dc5
a912ea5
b8e56bc
5b9025b
dbacb0a
f49a85d
4eba11f
5795a86
90f20d0
908862a
879999e
a15e0ca
85bbaf4
394676a
11d59a2
2a7386e
1b1fdde
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| #include "arrow/compute/registry_util.h" | ||
|
|
||
| namespace arrow { | ||
| namespace compute { | ||
|
|
||
| std::unique_ptr<FunctionRegistry> MakeFunctionRegistry() { | ||
| return FunctionRegistry::Make(GetFunctionRegistry()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the point of this mostly trivial function? Why not let the user call |
||
| } | ||
|
|
||
| } // namespace compute | ||
| } // namespace arrow | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| // NOTE: API is EXPERIMENTAL and will change without going through a | ||
| // deprecation cycle | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "arrow/compute/registry.h" | ||
| #include "arrow/util/visibility.h" | ||
|
|
||
| namespace arrow { | ||
| namespace compute { | ||
|
|
||
| /// \brief Make a nested function registry with the default one as parent | ||
| ARROW_EXPORT std::unique_ptr<FunctionRegistry> MakeFunctionRegistry(); | ||
|
|
||
| } // namespace compute | ||
| } // namespace arrow |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -331,6 +331,22 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer { | |
| return Status::OK(); | ||
| } | ||
|
|
||
| Status Init(compute::ExecNode* node) { | ||
| if (node == nullptr) { | ||
| return Status::Invalid("internal error - null node"); | ||
| } | ||
| auto schema = node->inputs()[0]->output_schema(); | ||
| if (schema.get() == nullptr) { | ||
| return Status::Invalid("internal error - null schema"); | ||
| } | ||
| if (schema_.get() == nullptr) { | ||
| schema_ = schema; | ||
| } else if (schema_.get() != schema.get()) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this really comparing the pointers by value? Don't you want to compare the underlying schemas instead? |
||
| return Status::Invalid("internal error - inconsistent schemata"); | ||
| } | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status Consume(compute::ExecBatch batch) override { | ||
| ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch, | ||
| batch.ToRecordBatch(schema_)); | ||
|
|
@@ -432,9 +448,15 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan, | |
| custom_metadata, std::move(dataset_writer), write_options); | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE( | ||
| auto node, | ||
| compute::MakeExecNode("consuming_sink", plan, std::move(inputs), | ||
| compute::ConsumingSinkNodeOptions{std::move(consumer)})); | ||
| auto node, compute::MakeExecNode("consuming_sink", plan, std::move(inputs), | ||
| compute::ConsumingSinkNodeOptions{consumer})); | ||
|
|
||
| // this is a workaround specific for Arrow Substrait code paths | ||
| // Arrow Substrait creates ExecNodeOptions instances within a Declaration | ||
| // at this stage, schemata have not yet been created since nodes haven't | ||
| // thus, the ConsumingSinkNodeOptions passed to consumer has a null schema | ||
| // the following call to Init fills in the schema using the node just created | ||
| ARROW_RETURN_NOT_OK(consumer->Init(node)); | ||
|
|
||
| return node; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,7 +49,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) | |
| ARROW_ENGINE_EXPORT | ||
| Result<ExtensionSet> GetExtensionSetFromPlan( | ||
| const substrait::Plan& plan, | ||
| const ExtensionIdRegistry* registry = default_extension_id_registry()); | ||
| const ExtensionIdRegistry* registry = default_extension_id_registry(), | ||
| bool exclude_functions = false); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add documentation for this parameter in the docstring above?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, as a nit, double negatives are not terrific, so I would instead suggest |
||
|
|
||
| } // namespace engine | ||
| } // namespace arrow | ||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,7 +16,6 @@ | |||||||||||
| // under the License. | ||||||||||||
|
|
||||||||||||
| #include "arrow/engine/substrait/relation_internal.h" | ||||||||||||
|
|
||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FTR, the convention would be to leave the blank link as it's separating the |
||||||||||||
| #include "arrow/compute/api_scalar.h" | ||||||||||||
| #include "arrow/compute/exec/options.h" | ||||||||||||
| #include "arrow/dataset/file_base.h" | ||||||||||||
|
|
@@ -52,6 +51,69 @@ Status CheckRelCommon(const RelMessage& rel) { | |||||||||||
| return Status::OK(); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| Result<FieldRef> FromProto(const substrait::Expression& expr, const std::string& what) { | ||||||||||||
| int32_t index; | ||||||||||||
| switch (expr.rex_type_case()) { | ||||||||||||
| case substrait::Expression::RexTypeCase::kSelection: { | ||||||||||||
| const auto& selection = expr.selection(); | ||||||||||||
| switch (selection.root_type_case()) { | ||||||||||||
| case substrait::Expression_FieldReference::RootTypeCase::kRootReference: { | ||||||||||||
| break; | ||||||||||||
| } | ||||||||||||
| default: { | ||||||||||||
| return Status::NotImplemented( | ||||||||||||
| std::string("substrait::Expression with non-root-reference for ") + what); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| switch (selection.reference_type_case()) { | ||||||||||||
| case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: { | ||||||||||||
| const auto& direct_reference = selection.direct_reference(); | ||||||||||||
| switch (direct_reference.reference_type_case()) { | ||||||||||||
| case substrait::Expression_ReferenceSegment::ReferenceTypeCase:: | ||||||||||||
| kStructField: { | ||||||||||||
| break; | ||||||||||||
| } | ||||||||||||
| default: { | ||||||||||||
| return Status::NotImplemented( | ||||||||||||
| std::string("substrait::Expression with non-struct-field for ") + what); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| const auto& struct_field = direct_reference.struct_field(); | ||||||||||||
| if (struct_field.has_child()) { | ||||||||||||
| return Status::NotImplemented( | ||||||||||||
| std::string("substrait::Expression with non-flat struct-field for ") + | ||||||||||||
| what); | ||||||||||||
| } | ||||||||||||
| index = struct_field.field(); | ||||||||||||
| break; | ||||||||||||
| } | ||||||||||||
| default: { | ||||||||||||
| return Status::NotImplemented( | ||||||||||||
| std::string("substrait::Expression with non-direct reference for ") + what); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| break; | ||||||||||||
| } | ||||||||||||
| default: { | ||||||||||||
| return Status::NotImplemented( | ||||||||||||
| std::string("substrait::Expression with non-selection for ") + what); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| return FieldRef(FieldPath({index})); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| Result<std::vector<FieldRef>> FromProto( | ||||||||||||
| const google::protobuf::RepeatedPtrField<substrait::Expression>& exprs, | ||||||||||||
| const std::string& what) { | ||||||||||||
| std::vector<FieldRef> fields; | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May want to presize this? |
||||||||||||
| int size = exprs.size(); | ||||||||||||
| for (int i = 0; i < size; i++) { | ||||||||||||
| ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(exprs[i], what)); | ||||||||||||
|
Comment on lines
+109
to
+111
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can probably use a for-range construct:
Suggested change
|
||||||||||||
| fields.push_back(field); | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| } | ||||||||||||
| return fields; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| Result<compute::Declaration> FromProto(const substrait::Rel& rel, | ||||||||||||
| const ExtensionSet& ext_set) { | ||||||||||||
| static bool dataset_init = false; | ||||||||||||
|
|
@@ -109,6 +171,8 @@ Result<compute::Declaration> FromProto(const substrait::Rel& rel, | |||||||||||
| path = item.uri_path_glob(); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| util::string_view uri_file{item.uri_file()}; | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems unused. |
||||||||||||
|
|
||||||||||||
| if (item.format() == | ||||||||||||
| substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) { | ||||||||||||
| format = std::make_shared<dataset::ParquetFileFormat>(); | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -101,13 +101,12 @@ DeclarationFactory MakeWriteDeclarationFactory( | |
| return [&write_options_factory]( | ||
| compute::Declaration input, | ||
| std::vector<std::string> names) -> Result<compute::Declaration> { | ||
| std::shared_ptr<dataset::WriteNodeOptions> options = write_options_factory(); | ||
| std::shared_ptr<compute::ExecNodeOptions> options = write_options_factory(); | ||
| if (options == NULLPTR) { | ||
| return Status::Invalid("write options factory is exhausted"); | ||
| } | ||
| compute::Declaration projected = ProjectByNamesDeclaration(input, names); | ||
| return compute::Declaration::Sequence( | ||
| {std::move(projected), {"write", std::move(*options)}}); | ||
| return compute::Declaration::Sequence({std::move(projected), {"write", options}}); | ||
| }; | ||
| } | ||
|
|
||
|
|
@@ -204,6 +203,48 @@ Result<compute::ExecPlan> DeserializePlan( | |
| return MakeSingleDeclarationPlan(declarations); | ||
| } | ||
|
|
||
| Result<std::vector<UdfDeclaration>> DeserializePlanUdfs( | ||
| const Buffer& buf, const ExtensionIdRegistry* registry) { | ||
| ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf)); | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry, true)); | ||
|
|
||
| std::vector<UdfDeclaration> decls; | ||
| /* | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, is this code that needs to be debugged and then enabled?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code compiles (and passes locally implemented tests I have) with code i proposed to Substrait that is pending agreement as noted in this explanation post. For now, it goes to show the logic that's going to be implemented here. |
||
| for (const auto& ext : plan.extensions()) { | ||
| switch (ext.mapping_type_case()) { | ||
| case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { | ||
| const auto& fn = ext.extension_function(); | ||
| if (fn.has_udf()) { | ||
| const auto& udf = fn.udf(); | ||
| const auto& in_types = udf.input_types(); | ||
| int size = in_types.size(); | ||
| std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types; | ||
| for (int i=0; i<size; i++) { | ||
| ARROW_ASSIGN_OR_RAISE(auto input_type, FromProto(in_types.Get(i), ext_set)); | ||
| input_types.push_back(std::move(input_type)); | ||
| } | ||
| ARROW_ASSIGN_OR_RAISE(auto output_type, FromProto(udf.output_type(), ext_set)); | ||
| decls.push_back(std::move(UdfDeclaration{ | ||
| fn.name(), | ||
| udf.code(), | ||
| udf.summary(), | ||
| udf.description(), | ||
| std::move(input_types), | ||
| std::move(output_type), | ||
| })); | ||
| } | ||
| break; | ||
| } | ||
| default: { | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| */ | ||
|
Comment on lines
+213
to
+244
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this not used?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the commented-out code-block explained here. |
||
| return decls; | ||
| } | ||
|
|
||
| Result<std::shared_ptr<Schema>> DeserializeSchema(const Buffer& buf, | ||
| const ExtensionSet& ext_set) { | ||
| ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer<substrait::NamedStruct>(buf)); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -115,6 +115,22 @@ ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan( | |||||||||||||||||
| const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options, | ||||||||||||||||||
| const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR); | ||||||||||||||||||
|
|
||||||||||||||||||
| /// Factory function type for generating the write options of a node consuming the batches | ||||||||||||||||||
| /// produced by each toplevel Substrait relation when deserializing a Substrait Plan. | ||||||||||||||||||
| using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>; | ||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems this is a duplicate declaration. |
||||||||||||||||||
|
|
||||||||||||||||||
| struct ARROW_ENGINE_EXPORT UdfDeclaration { | ||||||||||||||||||
| std::string name; | ||||||||||||||||||
| std::string code; | ||||||||||||||||||
| std::string summary; | ||||||||||||||||||
| std::string description; | ||||||||||||||||||
| std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types; | ||||||||||||||||||
| std::pair<std::shared_ptr<DataType>, bool> output_type; | ||||||||||||||||||
|
Comment on lines
+127
to
+128
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not obvious what the
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIRC, this pair-type originates from here, so the |
||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
| ARROW_ENGINE_EXPORT Result<std::vector<UdfDeclaration>> DeserializePlanUdfs( | ||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a docstring? |
||||||||||||||||||
| const Buffer& buf, const ExtensionIdRegistry* registry); | ||||||||||||||||||
|
|
||||||||||||||||||
| /// \brief Deserializes a Substrait Type message to the corresponding Arrow type | ||||||||||||||||||
| /// | ||||||||||||||||||
| /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,9 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerialized | |
| ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan( | ||
| const std::string& substrait_json); | ||
|
|
||
| ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans( | ||
| const Buffer& buf, const ExtensionIdRegistry* registry); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are already functions named Also, can you add a docstring?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main purpose of |
||
|
|
||
| /// \brief Make a nested registry with the default registry as parent. | ||
| /// See arrow::engine::nested_extension_id_registry for details. | ||
| ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,12 @@ class Status; | |
| class Table; | ||
| class Tensor; | ||
|
|
||
| namespace engine { | ||
|
|
||
| class ExtensionIdRegistry; | ||
|
|
||
| } // namespace engine | ||
|
|
||
| namespace py { | ||
|
|
||
| // Returns 0 on success, -1 on error. | ||
|
|
@@ -71,6 +77,8 @@ DECLARE_WRAP_FUNCTIONS(tensor, Tensor) | |
| DECLARE_WRAP_FUNCTIONS(batch, RecordBatch) | ||
| DECLARE_WRAP_FUNCTIONS(table, Table) | ||
|
|
||
| DECLARE_WRAP_FUNCTIONS(extension_id_registry, engine::ExtensionIdRegistry) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, this will expose the wrapper functions to C++ code, which doesn't seem to be used anywhere. Instead, you should wrap/unwrap purely on the Cython side, like for most other C++ classes.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This exposure is for PyArrow. It will be used in an upcoming PR, which should not be merged into this one. The purpose of exposing
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't really address the following comment:
|
||
|
|
||
| #undef DECLARE_WRAP_FUNCTIONS | ||
|
|
||
| namespace internal { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't remove this, this matches the opening brace in
\addtogroup execnode-optionsabove.