Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cpp/proto/substrait/extension_rels.proto
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,14 @@ message NamedTapRel {
// If empty, field names will be automatically generated.
repeated string columns = 3;
}

message SegmentedAggregateRel {
// Grouping keys of the aggregation
repeated substrait.Expression.FieldReference grouping_keys = 1;

// Segment keys of the aggregation
repeated substrait.Expression.FieldReference segment_keys = 2;

// A list of one or more aggregate expressions along with an optional filter.
repeated substrait.AggregateRel.Measure measures = 3;
}
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/exec/source_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
#include "arrow/compute/expression.h"
#include "arrow/datum.h"
#include "arrow/io/util_internal.h"
#include "arrow/ipc/util.h"
#include "arrow/result.h"
#include "arrow/table.h"
#include "arrow/util/align_util.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/async_util.h"
#include "arrow/util/checked_cast.h"
Expand Down Expand Up @@ -102,6 +104,13 @@ struct SourceNode : ExecNode, public TracedNode {
batch_size = morsel_length;
}
ExecBatch batch = morsel.Slice(offset, batch_size);
for (auto& value : batch.values) {
if (value.is_array()) {
ARROW_ASSIGN_OR_RAISE(
value, util::EnsureAlignment(value.make_array(), ipc::kArrowAlignment,
default_memory_pool()));
}
}
if (has_ordering) {
batch.index = batch_index;
}
Expand Down
161 changes: 100 additions & 61 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,105 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor* desc
return value_desc->name();
}

Result<compute::Expression> FromProto(const substrait::Expression::ReferenceSegment* ref,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options,
std::optional<compute::Expression> in_expr) {
auto in_ref = ref;
auto& current = in_expr;
while (ref != nullptr) {
switch (ref->reference_type_case()) {
case substrait::Expression::ReferenceSegment::kStructField: {
auto index = ref->struct_field().field();
if (!current) {
// Root StructField (column selection)
current = compute::field_ref(FieldRef(index));
} else if (auto current_ref = current->field_ref()) {
// Nested StructFields on the root (selection of struct-typed column
// combined with selecting struct fields)
current = compute::field_ref(FieldRef(*current_ref, index));
} else if (current->call() && current->call()->function_name == "struct_field") {
// Nested StructFields on top of an arbitrary expression
auto* field_options =
checked_cast<compute::StructFieldOptions*>(current->call()->options.get());
field_options->field_ref = FieldRef(std::move(field_options->field_ref), index);
} else {
// First StructField on top of an arbitrary expression
current = compute::call("struct_field", {std::move(*current)},
arrow::compute::StructFieldOptions({index}));
}

// Segment handled, continue with child segment (if any)
if (ref->struct_field().has_child()) {
ref = &ref->struct_field().child();
} else {
ref = nullptr;
}
break;
}
case substrait::Expression::ReferenceSegment::kListElement: {
if (!current) {
// Root ListField (illegal)
return Status::Invalid(
"substrait::ListElement cannot take a Relation as an argument");
}

// ListField on top of an arbitrary expression
current = compute::call(
"list_element",
{std::move(*current), compute::literal(ref->list_element().offset())});

// Segment handled, continue with child segment (if any)
if (ref->list_element().has_child()) {
ref = &ref->list_element().child();
} else {
ref = nullptr;
}
break;
}
default:
// Unimplemented construct, break current of loop
current.reset();
ref = nullptr;
}
}
if (current) {
return *std::move(current);
}

return Status::NotImplemented(
"conversion to arrow::compute::Expression from Substrait reference segment: ",
in_ref ? in_ref->DebugString() : "null");
}

Result<compute::Expression> FromProto(const substrait::Expression::FieldReference* fref,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options,
std::optional<compute::Expression> in_expr) {
if (fref->reference_type_case() !=
substrait::Expression::FieldReference::kDirectReference ||
fref->root_type_case() != substrait::Expression::FieldReference::kRootReference) {
return Status::NotImplemented("substrait::FieldReference not direct root reference");
}
auto& dref = fref->direct_reference();
return FromProto(&dref, ext_set, conversion_options, std::move(in_expr));
}

Result<FieldRef> DirectReferenceFromProto(
const substrait::Expression::FieldReference* fref, const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
FromProto(fref, ext_set, conversion_options, {}));
const FieldRef* field_ref = expr.field_ref();
if (field_ref) {
return *field_ref;
} else {
return Status::Invalid(
"A direct reference was expected but a more complex expression was given "
"instead");
}
}

Result<SubstraitCall> FromProto(const substrait::AggregateFunction& func, bool is_hash,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
Expand Down Expand Up @@ -193,67 +292,7 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
}

const auto* ref = &expr.selection().direct_reference();
while (ref != nullptr) {
switch (ref->reference_type_case()) {
case substrait::Expression::ReferenceSegment::kStructField: {
auto index = ref->struct_field().field();
if (!out) {
// Root StructField (column selection)
out = compute::field_ref(FieldRef(index));
} else if (auto out_ref = out->field_ref()) {
// Nested StructFields on the root (selection of struct-typed column
// combined with selecting struct fields)
out = compute::field_ref(FieldRef(*out_ref, index));
} else if (out->call() && out->call()->function_name == "struct_field") {
// Nested StructFields on top of an arbitrary expression
auto* field_options =
checked_cast<compute::StructFieldOptions*>(out->call()->options.get());
field_options->field_ref =
FieldRef(std::move(field_options->field_ref), index);
} else {
// First StructField on top of an arbitrary expression
out = compute::call("struct_field", {std::move(*out)},
arrow::compute::StructFieldOptions({index}));
}

// Segment handled, continue with child segment (if any)
if (ref->struct_field().has_child()) {
ref = &ref->struct_field().child();
} else {
ref = nullptr;
}
break;
}
case substrait::Expression::ReferenceSegment::kListElement: {
if (!out) {
// Root ListField (illegal)
return Status::Invalid(
"substrait::ListElement cannot take a Relation as an argument");
}

// ListField on top of an arbitrary expression
out = compute::call(
"list_element",
{std::move(*out), compute::literal(ref->list_element().offset())});

// Segment handled, continue with child segment (if any)
if (ref->list_element().has_child()) {
ref = &ref->list_element().child();
} else {
ref = nullptr;
}
break;
}
default:
// Unimplemented construct, break out of loop
out.reset();
ref = nullptr;
}
}
if (out) {
return *std::move(out);
}
break;
return FromProto(ref, ext_set, conversion_options, std::move(out));
}

case substrait::Expression::kIfThen: {
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/engine/substrait/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#pragma once

#include <memory>
#include <optional>

#include "arrow/compute/type_fwd.h"
#include "arrow/datum.h"
Expand All @@ -34,6 +35,15 @@ namespace engine {

class SubstraitCall;

ARROW_ENGINE_EXPORT
Result<FieldRef> DirectReferenceFromProto(const substrait::Expression::FieldReference*,
const ExtensionSet&, const ConversionOptions&);

ARROW_ENGINE_EXPORT
Result<compute::Expression> FromProto(const substrait::Expression::FieldReference*,
const ExtensionSet&, const ConversionOptions&,
std::optional<compute::Expression>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the expression here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the "current expression". This method is currently called int he middle of a deserializing an expression tree. So, for example:

flowchart TD
    A[Call] -->|args| FieldRef
    A -->|args| C
    C[Call*] -->|args| D
    C -->|Call| E
    D[Literal]
    E[FieldRef*]
Loading

So, when de-referencing FieldRef this will be Call and when dereferencing FieldRef* this will be Call*.

However, can we remove this prototype from the header file and put it in an anonymous namespace? I think it should be an internal method and not exposed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DirectReferenceFromProto is used by MakeSegmentedAggregateRel in options.cc, so I think it has to be in a header. I wouldn't say it's a public API, because it's in expression_internal.h.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Weston for the explanation.

+1 on not exposing this in the header file. I think if we move the MakeSegmentedAggregateRel to expression_internal.cc then we don't need to expose this via the header file. I have a comment about moving these "make extension rel" methods out of options.cc anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think expression_internal.cc, which deals with Substrait expression parsing, is the right place for any of the Make*Rel methods, which deal with Substrait extension handling.

IMO, refactoring of options.cc could be reasonable but is outside the scope of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, can we remove this prototype from the header file and put it in an anonymous namespace? I think it should be an internal method and not exposed.

@westonpace Would sth like this look better to you? e958b6d

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@westonpace Would sth like this look better to you? e958b6d

I'm neutral on the change and don't want to spend too much time bike shedding here. If you're comfortable let's move forward.

DirectReferenceFromProto is used by MakeSegmentedAggregateRel in options.cc, so I think it has to be in a header.

Yes, I am fine with DirectReferenceFromProto. I was merely asking about the FromProto variant that takes a reference segment.


ARROW_ENGINE_EXPORT
Result<compute::Expression> FromProto(const substrait::Expression&, const ExtensionSet&,
const ConversionOptions&);
Expand Down
73 changes: 73 additions & 0 deletions cpp/src/arrow/engine/substrait/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
rel.UnpackTo(&named_tap_rel);
return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set);
}
if (rel.Is<substrait_ext::SegmentedAggregateRel>()) {
substrait_ext::SegmentedAggregateRel seg_agg_rel;
rel.UnpackTo(&seg_agg_rel);
return MakeSegmentedAggregateRel(conv_opts, inputs, seg_agg_rel, ext_set);
}
return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ",
rel.DebugString());
}
Expand Down Expand Up @@ -165,6 +170,74 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
named_tap_rel.name(), renamed_schema));
return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt};
}

Result<RelationInfo> MakeSegmentedAggregateRel(
const ConversionOptions& conv_opts, const std::vector<DeclarationInfo>& inputs,
const substrait_ext::SegmentedAggregateRel& seg_agg_rel,
const ExtensionSet& ext_set) {
if (inputs.size() != 1) {
return Status::Invalid(
"substrait_ext::SegmentedAggregateRel requires a single input but got: ",
inputs.size());
}
if (seg_agg_rel.segment_keys_size() == 0) {
return Status::Invalid(
"substrait_ext::SegmentedAggregateRel requires at least one segment key");
}

auto input_schema = inputs[0].output_schema;

// store key fields to be used when output schema is created
std::vector<int> key_field_ids;
std::vector<FieldRef> keys;
for (auto& ref : seg_agg_rel.grouping_keys()) {
ARROW_ASSIGN_OR_RAISE(auto field_ref,
DirectReferenceFromProto(&ref, ext_set, conv_opts));
ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
key_field_ids.emplace_back(std::move(match[0]));
keys.emplace_back(std::move(field_ref));
}

// store segment key fields to be used when output schema is created
std::vector<int> segment_key_field_ids;
std::vector<FieldRef> segment_keys;
for (auto& ref : seg_agg_rel.segment_keys()) {
ARROW_ASSIGN_OR_RAISE(auto field_ref,
DirectReferenceFromProto(&ref, ext_set, conv_opts));
ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
segment_key_field_ids.emplace_back(std::move(match[0]));
segment_keys.emplace_back(std::move(field_ref));
}

std::vector<compute::Aggregate> aggregates;
aggregates.reserve(seg_agg_rel.measures_size());
std::vector<std::vector<int>> agg_src_fieldsets;
agg_src_fieldsets.reserve(seg_agg_rel.measures_size());
for (auto agg_measure : seg_agg_rel.measures()) {
ARROW_ASSIGN_OR_RAISE(
auto parsed_measure,
internal::ParseAggregateMeasure(agg_measure, ext_set, conv_opts,
/*is_hash=*/!keys.empty(), input_schema));
aggregates.push_back(std::move(parsed_measure.aggregate));
agg_src_fieldsets.push_back(std::move(parsed_measure.fieldset));
}

ARROW_ASSIGN_OR_RAISE(auto decl_info,
internal::MakeAggregateDeclaration(
std::move(inputs[0].declaration), std::move(input_schema),
seg_agg_rel.measures_size(), std::move(aggregates),
std::move(agg_src_fieldsets), std::move(keys),
std::move(key_field_ids), std::move(segment_keys),
std::move(segment_key_field_ids), ext_set, conv_opts));

const auto& output_schema = decl_info.output_schema;
size_t out_size = output_schema->num_fields();
std::vector<int> field_output_indices(out_size);
for (int i = 0; i < static_cast<int>(out_size); i++) {
field_output_indices[i] = i;
}
return RelationInfo{decl_info, std::move(field_output_indices)};
}
};

namespace {
Expand Down
Loading