Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
716a5b9
Substrait integrations
rtpsw Mar 11, 2022
8327b67
Substrait integrations
rtpsw Mar 11, 2022
d054f2a
Merge branch 'rtpsw-x1' of https://github.com/rtpsw/arrow into rtpsw-x1
rtpsw Mar 25, 2022
330ae66
Added end-to-end Substrait-to-Arrow enhancements
rtpsw Mar 25, 2022
abee905
Added logical comparison operators to Substrait registry
rtpsw Mar 27, 2022
3f3f3ef
Added as-of-merge execution
rtpsw Mar 31, 2022
98d2663
Added Substrait deserialization of flat field references for AsOfMerge
rtpsw Apr 8, 2022
5aa7ede
Support write-consumer of Arrow Substrait plan
rtpsw Apr 17, 2022
c0c0d08
Added explanation comment on MakeWriteNode
rtpsw Apr 28, 2022
f202dc5
Set use_threads on scan options of Arrow Substrait
rtpsw Apr 28, 2022
a912ea5
try
rtpsw May 13, 2022
b8e56bc
merge rtpsw-x1 and fix
rtpsw May 13, 2022
5b9025b
Merge branch 'master' into rtpsw-x2
rtpsw May 22, 2022
dbacb0a
integrated and tested
rtpsw May 22, 2022
f49a85d
UDF PoC
rtpsw May 24, 2022
4eba11f
merge master to rtpsw-x2
rtpsw May 27, 2022
5795a86
UDF PoC with scoped registries
rtpsw May 30, 2022
90f20d0
Fix parameter order and doc of DeserializePlan functions
rtpsw Jun 1, 2022
908862a
Merge branch 'master' into rtpsw-x2
rtpsw Jun 7, 2022
879999e
fix registry scoping
rtpsw Jun 9, 2022
a15e0ca
simple UDF benchmark
rtpsw Jun 10, 2022
85bbaf4
improved UDF PoC benchmark
rtpsw Jun 12, 2022
394676a
add substrait tests
rtpsw Jun 14, 2022
11d59a2
Merge branch 'master' into rtpsw-x2
rtpsw Jun 28, 2022
2a7386e
ARROW-16968: [C++] Expand Python-UDF support to Arrow Substrait
rtpsw Jul 3, 2022
1b1fdde
lint
rtpsw Jul 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ if(ARROW_COMPUTE)
compute/kernel.cc
compute/light_array.cc
compute/registry.cc
compute/registry_util.cc
compute/kernels/aggregate_basic.cc
compute/kernels/aggregate_mode.cc
compute/kernels/aggregate_quantile.cc
Expand Down
14 changes: 12 additions & 2 deletions cpp/src/arrow/compute/exec/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,18 @@ class ARROW_EXPORT SinkNodeConsumer {
virtual Future<> Finish() = 0;
};

class ARROW_EXPORT NullSinkNodeConsumer : public SinkNodeConsumer {
public:
Status Init(const std::shared_ptr<Schema>&, BackpressureControl*) override {
return Status::OK();
}
Status Consume(ExecBatch exec_batch) override { return Status::OK(); }
Future<> Finish() override { return Status::OK(); }
static std::shared_ptr<NullSinkNodeConsumer> Make() {
return std::make_shared<NullSinkNodeConsumer>();
}
};

/// \brief Add a sink node which consumes data within the exec plan run
class ARROW_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions {
public:
Expand Down Expand Up @@ -438,7 +450,5 @@ class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions {
std::shared_ptr<Table>* output_table;
};

/// @}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't remove this, this matches the opening brace in \addtogroup execnode-options above.


} // namespace compute
} // namespace arrow
28 changes: 28 additions & 0 deletions cpp/src/arrow/compute/registry_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/registry_util.h"

namespace arrow {
namespace compute {

std::unique_ptr<FunctionRegistry> MakeFunctionRegistry() {
return FunctionRegistry::Make(GetFunctionRegistry());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of this mostly trivial function? Why not let the user call FunctionRegistry::Make directly?

}

} // namespace compute
} // namespace arrow
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/registry_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// NOTE: API is EXPERIMENTAL and will change without going through a
// deprecation cycle

#pragma once

#include "arrow/compute/registry.h"
#include "arrow/util/visibility.h"

namespace arrow {
namespace compute {

/// \brief Make a nested function registry with the default one as parent
ARROW_EXPORT std::unique_ptr<FunctionRegistry> MakeFunctionRegistry();

} // namespace compute
} // namespace arrow
28 changes: 25 additions & 3 deletions cpp/src/arrow/dataset/file_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,22 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
return Status::OK();
}

Status Init(compute::ExecNode* node) {
if (node == nullptr) {
return Status::Invalid("internal error - null node");
}
auto schema = node->inputs()[0]->output_schema();
if (schema.get() == nullptr) {
return Status::Invalid("internal error - null schema");
}
if (schema_.get() == nullptr) {
schema_ = schema;
} else if (schema_.get() != schema.get()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really comparing the pointers by value? Don't you want to compare the underlying schemas instead?

return Status::Invalid("internal error - inconsistent schemata");
}
return Status::OK();
}

Status Consume(compute::ExecBatch batch) override {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
batch.ToRecordBatch(schema_));
Expand Down Expand Up @@ -432,9 +448,15 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
custom_metadata, std::move(dataset_writer), write_options);

ARROW_ASSIGN_OR_RAISE(
auto node,
compute::MakeExecNode("consuming_sink", plan, std::move(inputs),
compute::ConsumingSinkNodeOptions{std::move(consumer)}));
auto node, compute::MakeExecNode("consuming_sink", plan, std::move(inputs),
compute::ConsumingSinkNodeOptions{consumer}));

// this is a workaround specific for Arrow Substrait code paths
// Arrow Substrait creates ExecNodeOptions instances within a Declaration
// at this stage, schemata have not yet been created since nodes haven't
// thus, the ConsumingSinkNodeOptions passed to consumer has a null schema
// the following call to Init fills in the schema using the node just created
ARROW_RETURN_NOT_OK(consumer->Init(node));

return node;
}
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/arrow/engine/substrait/plan_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
}

Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
const ExtensionIdRegistry* registry) {
const ExtensionIdRegistry* registry,
bool exclude_functions) {
if (registry == NULLPTR) {
registry = default_extension_id_registry();
}
Expand Down Expand Up @@ -121,6 +122,9 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
}

case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: {
if (exclude_functions) {
break;
}
const auto& fn = ext.extension_function();
util::string_view uri = uris[fn.extension_uri_reference()];
function_ids[fn.function_anchor()] = Id{uri, fn.name()};
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/engine/substrait/plan_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
ARROW_ENGINE_EXPORT
Result<ExtensionSet> GetExtensionSetFromPlan(
const substrait::Plan& plan,
const ExtensionIdRegistry* registry = default_extension_id_registry());
const ExtensionIdRegistry* registry = default_extension_id_registry(),
bool exclude_functions = false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add documentation for this parameter in the docstring above?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, as a nit, double negatives are not terrific, so I would instead suggest bool include_functions = true.


} // namespace engine
} // namespace arrow
66 changes: 65 additions & 1 deletion cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
// under the License.

#include "arrow/engine/substrait/relation_internal.h"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FTR, the convention would be to leave the blank link as it's separating the .h corresponding to this .cc from other included headers.

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/exec/options.h"
#include "arrow/dataset/file_base.h"
Expand Down Expand Up @@ -52,6 +51,69 @@ Status CheckRelCommon(const RelMessage& rel) {
return Status::OK();
}

Result<FieldRef> FromProto(const substrait::Expression& expr, const std::string& what) {
int32_t index;
switch (expr.rex_type_case()) {
case substrait::Expression::RexTypeCase::kSelection: {
const auto& selection = expr.selection();
switch (selection.root_type_case()) {
case substrait::Expression_FieldReference::RootTypeCase::kRootReference: {
break;
}
default: {
return Status::NotImplemented(
std::string("substrait::Expression with non-root-reference for ") + what);
}
}
switch (selection.reference_type_case()) {
case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: {
const auto& direct_reference = selection.direct_reference();
switch (direct_reference.reference_type_case()) {
case substrait::Expression_ReferenceSegment::ReferenceTypeCase::
kStructField: {
break;
}
default: {
return Status::NotImplemented(
std::string("substrait::Expression with non-struct-field for ") + what);
}
}
const auto& struct_field = direct_reference.struct_field();
if (struct_field.has_child()) {
return Status::NotImplemented(
std::string("substrait::Expression with non-flat struct-field for ") +
what);
}
index = struct_field.field();
break;
}
default: {
return Status::NotImplemented(
std::string("substrait::Expression with non-direct reference for ") + what);
}
}
break;
}
default: {
return Status::NotImplemented(
std::string("substrait::Expression with non-selection for ") + what);
}
}
return FieldRef(FieldPath({index}));
}

Result<std::vector<FieldRef>> FromProto(
const google::protobuf::RepeatedPtrField<substrait::Expression>& exprs,
const std::string& what) {
std::vector<FieldRef> fields;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May want to presize this?

int size = exprs.size();
for (int i = 0; i < size; i++) {
ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(exprs[i], what));
Comment on lines +109 to +111
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably use a for-range construct:

Suggested change
int size = exprs.size();
for (int i = 0; i < size; i++) {
ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(exprs[i], what));
for (const auto& expr : exprs) {
ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(expr, what));

fields.push_back(field);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fields.push_back(field);
fields.push_back(std::move(field));

}
return fields;
}

Result<compute::Declaration> FromProto(const substrait::Rel& rel,
const ExtensionSet& ext_set) {
static bool dataset_init = false;
Expand Down Expand Up @@ -109,6 +171,8 @@ Result<compute::Declaration> FromProto(const substrait::Rel& rel,
path = item.uri_path_glob();
}

util::string_view uri_file{item.uri_file()};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems unused.


if (item.format() ==
substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) {
format = std::make_shared<dataset::ParquetFileFormat>();
Expand Down
47 changes: 44 additions & 3 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,12 @@ DeclarationFactory MakeWriteDeclarationFactory(
return [&write_options_factory](
compute::Declaration input,
std::vector<std::string> names) -> Result<compute::Declaration> {
std::shared_ptr<dataset::WriteNodeOptions> options = write_options_factory();
std::shared_ptr<compute::ExecNodeOptions> options = write_options_factory();
if (options == NULLPTR) {
return Status::Invalid("write options factory is exhausted");
}
compute::Declaration projected = ProjectByNamesDeclaration(input, names);
return compute::Declaration::Sequence(
{std::move(projected), {"write", std::move(*options)}});
return compute::Declaration::Sequence({std::move(projected), {"write", options}});
};
}

Expand Down Expand Up @@ -204,6 +203,48 @@ Result<compute::ExecPlan> DeserializePlan(
return MakeSingleDeclarationPlan(declarations);
}

Result<std::vector<UdfDeclaration>> DeserializePlanUdfs(
const Buffer& buf, const ExtensionIdRegistry* registry) {
ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf));

ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry, true));

std::vector<UdfDeclaration> decls;
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, is this code that needs to be debugged and then enabled?
If this PR is not finished, could you mark it as draft?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code compiles (and passes locally implemented tests I have) with code i proposed to Substrait that is pending agreement as noted in this explanation post. For now, it goes to show the logic that's going to be implemented here.

for (const auto& ext : plan.extensions()) {
switch (ext.mapping_type_case()) {
case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: {
const auto& fn = ext.extension_function();
if (fn.has_udf()) {
const auto& udf = fn.udf();
const auto& in_types = udf.input_types();
int size = in_types.size();
std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types;
for (int i=0; i<size; i++) {
ARROW_ASSIGN_OR_RAISE(auto input_type, FromProto(in_types.Get(i), ext_set));
input_types.push_back(std::move(input_type));
}
ARROW_ASSIGN_OR_RAISE(auto output_type, FromProto(udf.output_type(), ext_set));
decls.push_back(std::move(UdfDeclaration{
fn.name(),
udf.code(),
udf.summary(),
udf.description(),
std::move(input_types),
std::move(output_type),
}));
}
break;
}
default: {
break;
}
}
}
*/
Comment on lines +213 to +244
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the commented-out code-block explained here.

return decls;
}

Result<std::shared_ptr<Schema>> DeserializeSchema(const Buffer& buf,
const ExtensionSet& ext_set) {
ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer<substrait::NamedStruct>(buf));
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/engine/substrait/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// Factory function type for generating the write options of a node consuming the batches
/// produced by each toplevel Substrait relation when deserializing a Substrait Plan.
using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this is a duplicate declaration.


struct ARROW_ENGINE_EXPORT UdfDeclaration {
std::string name;
std::string code;
std::string summary;
std::string description;
std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types;
std::pair<std::shared_ptr<DataType>, bool> output_type;
Comment on lines +127 to +128
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not obvious what the bool is for. Can you perhaps use a helper struct, e.g.:

Suggested change
std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types;
std::pair<std::shared_ptr<DataType>, bool> output_type;
struct TypeDeclaration {
std::shared_ptr<DataType> type;
bool xxx_some_suitable_name;
};
std::vector<TypeDeclaration> input_types;
TypeDeclaration output_type;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, this pair-type originates from here, so the bool means is_nullable. I'll see if all relevant places in the code can be cleaned up.

};

ARROW_ENGINE_EXPORT Result<std::vector<UdfDeclaration>> DeserializePlanUdfs(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a docstring?

const Buffer& buf, const ExtensionIdRegistry* registry);

/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/arrow/engine/substrait/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ Result<std::shared_ptr<Buffer>> SerializeJsonPlan(const std::string& substrait_j
return engine::internal::SubstraitFromJSON("Plan", substrait_json);
}

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buffer, const ExtensionIdRegistry* registry) {
return engine::DeserializePlans(
buffer, []() { return std::make_shared<compute::NullSinkNodeConsumer>(); },
registry);
}

std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry() {
return nested_extension_id_registry(default_extension_id_registry());
}
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/engine/substrait/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerialized
ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan(
const std::string& substrait_json);

ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ExtensionIdRegistry* registry);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are already functions named DeserializePlans in serde.h. Isn't it a bit confusing to have another one similarly named here?

Also, can you add a docstring?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main purpose of util.h is to expose functions to PyArrow. Nevertheless, I'll go over these function to see what can be simplified,


/// \brief Make a nested registry with the default registry as parent.
/// See arrow::engine::nested_extension_id_registry for details.
ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry();
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/python/pyarrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class Status;
class Table;
class Tensor;

namespace engine {

class ExtensionIdRegistry;

} // namespace engine

namespace py {

// Returns 0 on success, -1 on error.
Expand Down Expand Up @@ -71,6 +77,8 @@ DECLARE_WRAP_FUNCTIONS(tensor, Tensor)
DECLARE_WRAP_FUNCTIONS(batch, RecordBatch)
DECLARE_WRAP_FUNCTIONS(table, Table)

DECLARE_WRAP_FUNCTIONS(extension_id_registry, engine::ExtensionIdRegistry)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this will expose the wrapper functions to C++ code, which doesn't seem to be used anywhere. Instead, you should wrap/unwrap purely on the Cython side, like for most other C++ classes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exposure is for PyArrow. It will be used in an upcoming PR, which should not be merged into this one. The purpose of exposing ExtensionIdRegistry is to allow use of a nested registry from PyArrow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't really address the following comment:

Instead, you should wrap/unwrap purely on the Cython side, like for most other C++ classes.


#undef DECLARE_WRAP_FUNCTIONS

namespace internal {
Expand Down
Loading