Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6f4a3f6
feat(project): adding project test case for substrait with minor chan…
vibhatha Aug 16, 2022
9367049
feat(emit): initial version of emit with project added
vibhatha Aug 16, 2022
a83deec
fix(test): fixing the test feature
vibhatha Aug 17, 2022
e55b2c8
feat(data-gen): adding data generator script wip
vibhatha Aug 17, 2022
7fc83cb
fix(format): refactor to simplify tests
vibhatha Aug 17, 2022
43fea24
feat(filter): adding filter emit
vibhatha Aug 18, 2022
136edf5
feat(join): adding join example
vibhatha Aug 18, 2022
a2c08a2
fix(rebase): merge with substrait changes
vibhatha Aug 19, 2022
9578a91
fix(project): replaced the add op with equal for test case
vibhatha Aug 19, 2022
fb77dc1
feat(aggregate): basic end-to-end test added
vibhatha Aug 19, 2022
ea2a05c
feat(agg): adding aggregate feature for emits
vibhatha Aug 21, 2022
8da8b54
fix(num_columns): fix the number of columns for emit feature
vibhatha Aug 21, 2022
1f4da76
fix(cleanup): cleaning up code
vibhatha Aug 21, 2022
d99ddf4
fix(reviews): remove column count from DeclarationInfo
vibhatha Aug 31, 2022
81ad00b
fix(reviews): removed a redundant loop
vibhatha Aug 31, 2022
bba665f
fix(reviews): updated the emit processing logic and added switch cases
vibhatha Aug 31, 2022
7eb4623
fix(path_issue): added a check for replacing clause
vibhatha Sep 1, 2022
54b18df
fix(path): remove temp path fix
vibhatha Sep 1, 2022
5051070
fix(reviews): imd commit
vibhatha Sep 7, 2022
5bb1051
fix(read): namedTable emit config added
vibhatha Sep 7, 2022
19e49ed
fix(rebase)
vibhatha Sep 9, 2022
2416e95
fix(reviews): address reviews
vibhatha Sep 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 133 additions & 27 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,45 @@ using internal::make_unique;
namespace engine {

template <typename RelMessage>
Status CheckRelCommon(const RelMessage& rel) {
Result<std::vector<compute::Expression>> GetEmitInfo(
const RelMessage& rel, const std::shared_ptr<Schema>& schema) {
const auto& emit = rel.common().emit();
int emit_size = emit.output_mapping_size();
std::vector<compute::Expression> proj_field_refs(emit_size);
for (int i = 0; i < emit_size; i++) {
int32_t map_id = emit.output_mapping(i);
proj_field_refs[i] = compute::field_ref(FieldRef(map_id));
}
return std::move(proj_field_refs);
}

template <typename RelMessage>
/// \brief Apply a relation's optional RelCommon emit clause to a declaration.
///
/// \param rel the Substrait relation message that may carry a RelCommon.
/// \param no_emit_declr the declaration (and its schema) produced for the
///   relation before considering the emit clause.
/// \param schema the relation's pre-emit output schema.
/// \return `no_emit_declr` unchanged for direct emits, or the declaration
///   wrapped in a "project" node that selects/reorders columns per the emit
///   mapping, paired with the correspondingly remapped output schema.
Result<DeclarationInfo> ProcessEmit(const RelMessage& rel,
                                    const DeclarationInfo& no_emit_declr,
                                    const std::shared_ptr<Schema>& schema) {
  // No RelCommon at all is equivalent to a direct (pass-through) emit.
  if (!rel.has_common()) {
    return no_emit_declr;
  }
  switch (rel.common().emit_kind_case()) {
    case substrait::RelCommon::EmitKindCase::kDirect:
      return no_emit_declr;
    case substrait::RelCommon::EmitKindCase::kEmit: {
      ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(rel, schema));
      // The emit mapping selects and reorders columns, so the output schema
      // must be rebuilt from the mapping; reusing the pre-emit schema would
      // misdescribe the projected output to downstream consumers.
      const auto& emit = rel.common().emit();
      int emit_size = emit.output_mapping_size();
      FieldVector emit_fields(emit_size);
      for (int i = 0; i < emit_size; i++) {
        int32_t map_id = emit.output_mapping(i);
        if (map_id < 0 || map_id >= schema->num_fields()) {
          return Status::Invalid("substrait::RelCommon::Emit output mapping index ",
                                 map_id, " is out of range for a schema with ",
                                 schema->num_fields(), " fields");
        }
        emit_fields[i] = schema->field(map_id);
      }
      return DeclarationInfo{
          compute::Declaration::Sequence(
              {no_emit_declr.declaration,
               {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}),
          arrow::schema(std::move(emit_fields))};
    }
    default:
      // Covers EMIT_KIND_NOT_SET and any future additions to the oneof.
      return Status::Invalid("Invalid emit case");
  }
}

template <typename RelMessage>
Status CheckRelCommon(const RelMessage& rel) {
if (rel.has_common()) {
if (rel.common().has_hint()) {
return Status::NotImplemented("substrait::RelCommon::Hint");
}
Expand Down Expand Up @@ -75,7 +109,6 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&

ARROW_ASSIGN_OR_RAISE(auto base_schema,
FromProto(read.base_schema(), ext_set, conversion_options));
auto num_columns = static_cast<int>(base_schema->fields().size());

auto scan_options = std::make_shared<dataset::ScanOptions>();
scan_options->use_threads = true;
Expand Down Expand Up @@ -104,7 +137,9 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
named_table.names().end());
ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
named_table_provider(table_names));
return DeclarationInfo{std::move(source_decl), num_columns};
return ProcessEmit(std::move(read),
DeclarationInfo{std::move(source_decl), base_schema},
std::move(base_schema));
}

if (!read.has_local_files()) {
Expand Down Expand Up @@ -216,12 +251,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
std::move(filesystem), std::move(files),
std::move(format), {}));

ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema)));
ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema));

DeclarationInfo scan_declaration = {
compute::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}},
base_schema};

return DeclarationInfo{
compute::Declaration{
"scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}},
num_columns};
return ProcessEmit(std::move(read), std::move(scan_declaration),
std::move(base_schema));
}

case substrait::Rel::RelTypeCase::kFilter: {
Expand All @@ -239,19 +276,20 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
}
ARROW_ASSIGN_OR_RAISE(auto condition,
FromProto(filter.condition(), ext_set, conversion_options));

return DeclarationInfo{
DeclarationInfo filter_declaration{
compute::Declaration::Sequence({
std::move(input.declaration),
{"filter", compute::FilterNodeOptions{std::move(condition)}},
}),
input.num_columns};
input.output_schema};

return ProcessEmit(std::move(filter), std::move(filter_declaration),
input.output_schema);
}

case substrait::Rel::RelTypeCase::kProject: {
const auto& project = rel.project();
RETURN_NOT_OK(CheckRelCommon(project));

if (!project.has_input()) {
return Status::Invalid("substrait::ProjectRel with no input relation");
}
Expand All @@ -261,23 +299,48 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
// NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces
// them. Therefore, we need to prefix all the current columns for compatibility.
std::vector<compute::Expression> expressions;
expressions.reserve(input.num_columns + project.expressions().size());
for (int i = 0; i < input.num_columns; i++) {
int num_columns = input.output_schema->num_fields();
expressions.reserve(num_columns + project.expressions().size());
for (int i = 0; i < num_columns; i++) {
expressions.emplace_back(compute::field_ref(FieldRef(i)));
}

int i = 0;
auto project_schema = input.output_schema;
for (const auto& expr : project.expressions()) {
expressions.emplace_back();
ARROW_ASSIGN_OR_RAISE(expressions.back(),
std::shared_ptr<Field> project_field;
ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr,
FromProto(expr, ext_set, conversion_options));
auto bound_expr = des_expr.Bind(*input.output_schema);
if (auto* expr_call = bound_expr->call()) {
project_field = field(expr_call->function_name,
expr_call->kernel->signature->out_type().type());
} else if (auto* field_ref = des_expr.field_ref()) {
ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
field_ref->FindOne(*input.output_schema));
ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema));
} else if (auto* literal = des_expr.literal()) {
project_field =
field("field_" + std::to_string(num_columns + i), literal->type());
}
ARROW_ASSIGN_OR_RAISE(
project_schema,
project_schema->AddField(
num_columns + static_cast<int>(project.expressions().size()) - 1,
std::move(project_field)));
i++;
expressions.emplace_back(des_expr);
}

auto num_columns = static_cast<int>(expressions.size());
return DeclarationInfo{
DeclarationInfo project_declaration{
compute::Declaration::Sequence({
std::move(input.declaration),
{"project", compute::ProjectNodeOptions{std::move(expressions)}},
}),
num_columns};
project_schema};

return ProcessEmit(std::move(project), std::move(project_declaration),
std::move(project_schema));
}

case substrait::Rel::RelTypeCase::kJoin: {
Expand Down Expand Up @@ -355,15 +418,26 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
if (!left_keys || !right_keys) {
return Status::Invalid("Left keys for join cannot be null");
}

// Create output schema from left, right relations and join keys
FieldVector combined_fields = left.output_schema->fields();
const FieldVector& right_fields = right.output_schema->fields();
combined_fields.insert(combined_fields.end(), right_fields.begin(),
right_fields.end());
std::shared_ptr<Schema> join_schema = schema(std::move(combined_fields));

compute::HashJoinNodeOptions join_options{{std::move(*left_keys)},
{std::move(*right_keys)}};
join_options.join_type = join_type;
join_options.key_cmp = {join_key_cmp};
compute::Declaration join_dec{"hashjoin", std::move(join_options)};
auto num_columns = left.num_columns + right.num_columns;
join_dec.inputs.emplace_back(std::move(left.declaration));
join_dec.inputs.emplace_back(std::move(right.declaration));
return DeclarationInfo{std::move(join_dec), num_columns};

DeclarationInfo join_declaration{std::move(join_dec), join_schema};

return ProcessEmit(std::move(join), std::move(join_declaration),
std::move(join_schema));
}
case substrait::Rel::RelTypeCase::kAggregate: {
const auto& aggregate = rel.aggregate();
Expand All @@ -381,16 +455,25 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
"Grouping sets not supported. AggregateRel::groupings may not have more "
"than one item");
}

// prepare output schema from aggregates
auto input_schema = input.output_schema;
// store key fields to be used when output schema is created
std::vector<int> key_field_ids;
std::vector<FieldRef> keys;
if (aggregate.groupings_size() > 0) {
const substrait::AggregateRel::Grouping& group = aggregate.groupings(0);
keys.reserve(group.grouping_expressions_size());
for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) {
int grouping_expr_size = group.grouping_expressions_size();
keys.reserve(grouping_expr_size);
key_field_ids.reserve(grouping_expr_size);
for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) {
ARROW_ASSIGN_OR_RAISE(
compute::Expression expr,
FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options));
const FieldRef* field_ref = expr.field_ref();
if (field_ref) {
ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema));
key_field_ids.emplace_back(std::move(match[0]));
keys.emplace_back(std::move(*field_ref));
} else {
return Status::Invalid(
Expand All @@ -402,6 +485,8 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
int measure_size = aggregate.measures_size();
std::vector<compute::Aggregate> aggregates;
aggregates.reserve(measure_size);
// store aggregate fields to be used when output schema is created
std::vector<int> agg_src_field_ids(measure_size);
for (int measure_id = 0; measure_id < measure_size; measure_id++) {
const auto& agg_measure = aggregate.measures(measure_id);
if (agg_measure.has_measure()) {
Expand All @@ -416,17 +501,38 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
ExtensionIdRegistry::SubstraitAggregateToArrow converter,
ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id()));
ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));

// find aggregate field ids from schema
const auto field_ref = arrow_agg.target;
ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
agg_src_field_ids[measure_id] = match[0];

aggregates.push_back(std::move(arrow_agg));
} else {
return Status::Invalid("substrait::AggregateFunction not provided");
}
}
FieldVector output_fields;
output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size());
// extract aggregate fields to output schema
for (int id = 0; id < static_cast<int>(agg_src_field_ids.size()); id++) {
output_fields.emplace_back(input_schema->field(agg_src_field_ids[id]));
}
// extract key fields to output schema
for (int id = 0; id < static_cast<int>(key_field_ids.size()); id++) {
output_fields.emplace_back(input_schema->field(key_field_ids[id]));
}
Comment on lines +518 to +524
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong order; keys come first.

The list of distinct columns from each grouping set (ordered by their first appearance) followed by the list of measures in declaration order, [...]

https://substrait.io/relations/logical_relations/#aggregate-operation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jvanstraten I also noticed this, but I forgot to leave a comment about it. This is probably a separate JIRA because of the order used in the aggregate_node.cc[1]. Please refer to the comment in this line and the two loops after that. The aggregate fields are appended first and then the key fields. One thing we can do is swap the response here.

cc @westonpace

[1].

// Aggregate fields come before key fields to match the behavior of GroupBy function

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the comment this looks like intentional behavior of Arrow, so I don't think aggregate node is going to be adjusted to match Substrait. So that just means there should be a project node inserted behind the aggregate node that moves the columns around accordingly, right? I guess you could fix that in a separate JIRA/PR though. Maybe add a FIXME comment in that case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the requirement in Acero may be static here.
We can use the project to swap things around and document it properly. Probably we can do it in this PR as well.

cc @westonpace

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thoughts, it would be better to solve this one in another PR. Because I am not quite sure if this would break R test cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another PR is fine but I wouldn't consider the output order from Acero to be too static. Fixing it up to output things in the order Substrait expects would be nice so we can at least avoid the project node in some cases (when a direct emit). It'll be a breaking change and probably cause some slight heartburn to our existing tests but we should probably fix it while we still have the opportunity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));

return DeclarationInfo{
DeclarationInfo aggregate_declaration{
compute::Declaration::Sequence(
{std::move(input.declaration),
{"aggregate", compute::AggregateNodeOptions{aggregates, keys}}}),
static_cast<int>(aggregates.size())};
aggregate_schema};

return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration),
std::move(aggregate_schema));
}

default:
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/arrow/engine/substrait/relation_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ struct DeclarationInfo {
/// The compute declaration produced thus far.
compute::Declaration declaration;

/// The number of columns returned by the declaration.
int num_columns;
std::shared_ptr<Schema> output_schema;
};

/// \brief Convert a Substrait Rel object to an Acero declaration
Expand Down
Loading