Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cpp/src/arrow/compute/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ARROW_EXPORT FunctionRegistry {

/// \brief Check whether a new function options type can be added to the registry.
///
/// \returns Status::KeyError if a function options type with the same name is already
/// \return Status::KeyError if a function options type with the same name is already
/// registered.
Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type,
bool allow_overwrite = false);
Expand Down Expand Up @@ -115,8 +115,6 @@ class ARROW_EXPORT FunctionRegistry {
std::unique_ptr<FunctionRegistryImpl> impl_;

explicit FunctionRegistry(FunctionRegistryImpl* impl);

class NestedFunctionRegistryImpl;
};

/// \brief Return the process-global function registry.
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
return Status::OK();
}

Status RegisterFunction(std::string uri, std::string name,
std::string arrow_function_name) override {
return RegisterFunction({uri, name}, arrow_function_name);
}

// owning storage of uris, names, (arrow::)function_names, types
// note that storing strings like this is safe since references into an
// unordered_set are not invalidated on insertion
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
util::string_view arrow_function_name) const = 0;
virtual Status CanRegisterFunction(Id,
const std::string& arrow_function_name) const = 0;
// registers a function without taking ownership of uri and name within Id
virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0;
// registers a function while taking ownership of uri and name
virtual Status RegisterFunction(std::string uri, std::string name,
std::string arrow_function_name) = 0;
};

constexpr util::string_view kArrowExtTypesUri =
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/engine/substrait/plan_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)

Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
const ExtensionIdRegistry* registry) {
if (registry == NULLPTR) {
registry = default_extension_id_registry();
}
std::unordered_map<uint32_t, util::string_view> uris;
uris.reserve(plan.extension_uris_size());
for (const auto& uri : plan.extension_uris()) {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/exec/options.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_ipc.h"
#include "arrow/dataset/file_parquet.h"
#include "arrow/dataset/plan.h"
#include "arrow/dataset/scanner.h"
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/engine/substrait/type_internal.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/filesystem/util_internal.h"

namespace arrow {
Expand Down Expand Up @@ -66,6 +68,7 @@ Result<compute::Declaration> FromProto(const substrait::Rel& rel,
ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set));

auto scan_options = std::make_shared<dataset::ScanOptions>();
scan_options->use_threads = true;

if (read.has_filter()) {
ARROW_ASSIGN_OR_RAISE(scan_options->filter, FromProto(read.filter(), ext_set));
Expand Down
127 changes: 113 additions & 14 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,65 @@ Result<compute::Declaration> DeserializeRelation(const Buffer& buf,
return FromProto(rel, ext_set);
}

using DeclarationFactory = std::function<Result<compute::Declaration>(
compute::Declaration, std::vector<std::string> names)>;

namespace {

DeclarationFactory MakeConsumingSinkDeclarationFactory(
const ConsumerFactory& consumer_factory) {
return [&consumer_factory](
compute::Declaration input,
std::vector<std::string> names) -> Result<compute::Declaration> {
std::shared_ptr<compute::SinkNodeConsumer> consumer = consumer_factory();
if (consumer == NULLPTR) {
return Status::Invalid("consumer factory is exhausted");
}
std::shared_ptr<compute::ExecNodeOptions> options =
std::make_shared<compute::ConsumingSinkNodeOptions>(
compute::ConsumingSinkNodeOptions{consumer_factory(), std::move(names)});
return compute::Declaration::Sequence(
{std::move(input), {"consuming_sink", options}});
};
}

compute::Declaration ProjectByNamesDeclaration(compute::Declaration input,
std::vector<std::string> names) {
int names_size = static_cast<int>(names.size());
if (names_size == 0) {
return input;
}
std::vector<compute::Expression> expressions;
for (int i = 0; i < names_size; i++) {
expressions.push_back(compute::field_ref(FieldRef(i)));
}
return compute::Declaration::Sequence(
{std::move(input),
{"project",
compute::ProjectNodeOptions{std::move(expressions), std::move(names)}}});
}

DeclarationFactory MakeWriteDeclarationFactory(
const WriteOptionsFactory& write_options_factory) {
return [&write_options_factory](
compute::Declaration input,
std::vector<std::string> names) -> Result<compute::Declaration> {
std::shared_ptr<dataset::WriteNodeOptions> options = write_options_factory();
if (options == NULLPTR) {
return Status::Invalid("write options factory is exhausted");
}
compute::Declaration projected = ProjectByNamesDeclaration(input, names);
return compute::Declaration::Sequence(
{std::move(projected), {"write", std::move(*options)}});
};
}

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out) {
const Buffer& buf, DeclarationFactory declaration_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf));

ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan));
ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry));

std::vector<compute::Declaration> sink_decls;
for (const substrait::PlanRel& plan_rel : plan.relations()) {
Expand All @@ -76,12 +129,9 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
names.assign(plan_rel.root().names().begin(), plan_rel.root().names().end());
}

// pipe each relation into a consuming_sink node
auto sink_decl = compute::Declaration::Sequence({
std::move(decl),
{"consuming_sink",
compute::ConsumingSinkNodeOptions{consumer_factory(), std::move(names)}},
});
// pipe each relation
ARROW_ASSIGN_OR_RAISE(auto sink_decl,
declaration_factory(std::move(decl), std::move(names)));
sink_decls.push_back(std::move(sink_decl));
}

Expand All @@ -91,11 +141,26 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
return sink_decls;
}

Result<compute::ExecPlan> DeserializePlan(const Buffer& buf,
const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out) {
ARROW_ASSIGN_OR_RAISE(auto declarations,
DeserializePlans(buf, consumer_factory, ext_set_out));
} // namespace

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ConsumerFactory& consumer_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
return DeserializePlans(buf, MakeConsumingSinkDeclarationFactory(consumer_factory),
registry, ext_set_out);
}

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const WriteOptionsFactory& write_options_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
return DeserializePlans(buf, MakeWriteDeclarationFactory(write_options_factory),
registry, ext_set_out);
}

namespace {

Result<compute::ExecPlan> MakeSingleDeclarationPlan(
std::vector<compute::Declaration> declarations) {
if (declarations.size() > 1) {
return Status::Invalid("DeserializePlan does not support multiple root relations");
} else {
Expand All @@ -105,6 +170,40 @@ Result<compute::ExecPlan> DeserializePlan(const Buffer& buf,
}
}

} // namespace

Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
bool factory_done = false;
auto single_consumer = [&factory_done, &consumer] {
if (factory_done) {
return std::shared_ptr<compute::SinkNodeConsumer>{};
}
factory_done = true;
return consumer;
};
ARROW_ASSIGN_OR_RAISE(auto declarations,
DeserializePlans(buf, single_consumer, registry, ext_set_out));
return MakeSingleDeclarationPlan(declarations);
}

Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
bool factory_done = false;
auto single_write_options = [&factory_done, &write_options] {
if (factory_done) {
return std::shared_ptr<dataset::WriteNodeOptions>{};
}
factory_done = true;
return write_options;
};
ARROW_ASSIGN_OR_RAISE(auto declarations, DeserializePlans(buf, single_write_options,
registry, ext_set_out));
return MakeSingleDeclarationPlan(declarations);
}

Result<std::shared_ptr<Schema>> DeserializeSchema(const Buffer& buf,
const ExtensionSet& ext_set) {
ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer<substrait::NamedStruct>(buf));
Expand Down
67 changes: 63 additions & 4 deletions cpp/src/arrow/engine/substrait/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/buffer.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/dataset/file_base.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/visibility.h"
#include "arrow/result.h"
Expand All @@ -40,21 +41,79 @@ using ConsumerFactory = std::function<std::shared_ptr<compute::SinkNodeConsumer>

/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations
///
/// The output of each top-level Substrait relation will be sent to a caller supplied
/// consumer function provided by consumer_factory
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] consumer_factory factory function for generating the node that consumes
/// the batches produced by each toplevel Substrait relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return a vector of ExecNode declarations, one for each toplevel relation in the
/// Substrait Plan
ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out = NULLPTR);
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// \brief Deserializes a single-relation Substrait Plan message to an execution plan
///
/// The output of each top-level Substrait relation will be sent to a caller supplied
/// consumer function provided by consumer_factory
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] consumer node that consumes the batches produced by each toplevel Substrait
/// relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return an ExecNode corresponding to the single toplevel relation in the Substrait
/// Plan
Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// Factory function type for generating the write options of a node consuming the batches
/// produced by each toplevel Substrait relation when deserializing a Substrait Plan.
using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>;

/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations
///
/// The output of each top-level Substrait relation will be written to a filesystem.
/// `write_options_factory` can be used to control write behavior.
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] write_options_factory factory function for generating the write options of
/// a node consuming the batches produced by each toplevel Substrait relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return a vector of ExecNode declarations, one for each toplevel relation in the
/// Substrait Plan
ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const WriteOptionsFactory& write_options_factory,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

Result<compute::ExecPlan> DeserializePlan(const Buffer& buf,
const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out = NULLPTR);
/// \brief Deserializes a single-relation Substrait Plan message to an execution plan
///
/// The output of the single Substrait relation will be written to a filesystem.
/// `write_options_factory` can be used to control write behavior.
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] write_options write options of a node consuming the batches produced by
/// each toplevel Substrait relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return a vector of ExecNode declarations, one for each toplevel relation in the
/// Substrait Plan
ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
///
Expand Down
Loading