-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-15584: [C++] Add support for Substrait's RelCommon::Emit #13914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6f4a3f6
9367049
a83deec
e55b2c8
7fc83cb
43fea24
136edf5
a2c08a2
9578a91
fb77dc1
ea2a05c
8da8b54
1f4da76
d99ddf4
81ad00b
bba665f
7eb4623
54b18df
5051070
5bb1051
19e49ed
2416e95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -42,11 +42,45 @@ using internal::make_unique; | |||
| namespace engine { | ||||
|
|
||||
| template <typename RelMessage> | ||||
| Status CheckRelCommon(const RelMessage& rel) { | ||||
| Result<std::vector<compute::Expression>> GetEmitInfo( | ||||
| const RelMessage& rel, const std::shared_ptr<Schema>& schema) { | ||||
| const auto& emit = rel.common().emit(); | ||||
| int emit_size = emit.output_mapping_size(); | ||||
| std::vector<compute::Expression> proj_field_refs(emit_size); | ||||
| for (int i = 0; i < emit_size; i++) { | ||||
| int32_t map_id = emit.output_mapping(i); | ||||
| proj_field_refs[i] = compute::field_ref(FieldRef(map_id)); | ||||
| } | ||||
| return std::move(proj_field_refs); | ||||
| } | ||||
|
|
||||
| template <typename RelMessage> | ||||
| Result<DeclarationInfo> ProcessEmit(const RelMessage& rel, | ||||
| const DeclarationInfo& no_emit_declr, | ||||
| const std::shared_ptr<Schema>& schema) { | ||||
| if (rel.has_common()) { | ||||
| if (rel.common().has_emit()) { | ||||
| return Status::NotImplemented("substrait::RelCommon::Emit"); | ||||
| switch (rel.common().emit_kind_case()) { | ||||
| case substrait::RelCommon::EmitKindCase::kDirect: | ||||
| return no_emit_declr; | ||||
| case substrait::RelCommon::EmitKindCase::kEmit: { | ||||
| ARROW_ASSIGN_OR_RAISE(auto emit_expressions, GetEmitInfo(rel, schema)); | ||||
| return DeclarationInfo{ | ||||
| compute::Declaration::Sequence( | ||||
| {no_emit_declr.declaration, | ||||
| {"project", compute::ProjectNodeOptions{std::move(emit_expressions)}}}), | ||||
| std::move(schema)}; | ||||
| } | ||||
| default: | ||||
| return Status::Invalid("Invalid emit case"); | ||||
| } | ||||
| } else { | ||||
| return no_emit_declr; | ||||
| } | ||||
| } | ||||
|
|
||||
| template <typename RelMessage> | ||||
| Status CheckRelCommon(const RelMessage& rel) { | ||||
| if (rel.has_common()) { | ||||
| if (rel.common().has_hint()) { | ||||
| return Status::NotImplemented("substrait::RelCommon::Hint"); | ||||
| } | ||||
|
|
@@ -75,7 +109,6 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
|
|
||||
| ARROW_ASSIGN_OR_RAISE(auto base_schema, | ||||
| FromProto(read.base_schema(), ext_set, conversion_options)); | ||||
| auto num_columns = static_cast<int>(base_schema->fields().size()); | ||||
|
|
||||
| auto scan_options = std::make_shared<dataset::ScanOptions>(); | ||||
| scan_options->use_threads = true; | ||||
|
|
@@ -104,7 +137,9 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| named_table.names().end()); | ||||
| ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, | ||||
| named_table_provider(table_names)); | ||||
| return DeclarationInfo{std::move(source_decl), num_columns}; | ||||
| return ProcessEmit(std::move(read), | ||||
| DeclarationInfo{std::move(source_decl), base_schema}, | ||||
| std::move(base_schema)); | ||||
| } | ||||
|
|
||||
| if (!read.has_local_files()) { | ||||
|
|
@@ -216,12 +251,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| std::move(filesystem), std::move(files), | ||||
| std::move(format), {})); | ||||
|
|
||||
| ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema))); | ||||
| ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema)); | ||||
|
|
||||
| DeclarationInfo scan_declaration = { | ||||
| compute::Declaration{"scan", dataset::ScanNodeOptions{ds, scan_options}}, | ||||
| base_schema}; | ||||
|
|
||||
| return DeclarationInfo{ | ||||
| compute::Declaration{ | ||||
| "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}, | ||||
| num_columns}; | ||||
| return ProcessEmit(std::move(read), std::move(scan_declaration), | ||||
| std::move(base_schema)); | ||||
| } | ||||
|
|
||||
| case substrait::Rel::RelTypeCase::kFilter: { | ||||
|
|
@@ -239,19 +276,20 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| } | ||||
| ARROW_ASSIGN_OR_RAISE(auto condition, | ||||
| FromProto(filter.condition(), ext_set, conversion_options)); | ||||
|
|
||||
| return DeclarationInfo{ | ||||
| DeclarationInfo filter_declaration{ | ||||
| compute::Declaration::Sequence({ | ||||
| std::move(input.declaration), | ||||
| {"filter", compute::FilterNodeOptions{std::move(condition)}}, | ||||
| }), | ||||
| input.num_columns}; | ||||
| input.output_schema}; | ||||
|
|
||||
| return ProcessEmit(std::move(filter), std::move(filter_declaration), | ||||
| input.output_schema); | ||||
| } | ||||
|
|
||||
| case substrait::Rel::RelTypeCase::kProject: { | ||||
| const auto& project = rel.project(); | ||||
| RETURN_NOT_OK(CheckRelCommon(project)); | ||||
|
|
||||
| if (!project.has_input()) { | ||||
| return Status::Invalid("substrait::ProjectRel with no input relation"); | ||||
| } | ||||
|
|
@@ -261,23 +299,48 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces | ||||
| // them. Therefore, we need to prefix all the current columns for compatibility. | ||||
| std::vector<compute::Expression> expressions; | ||||
| expressions.reserve(input.num_columns + project.expressions().size()); | ||||
| for (int i = 0; i < input.num_columns; i++) { | ||||
| int num_columns = input.output_schema->num_fields(); | ||||
| expressions.reserve(num_columns + project.expressions().size()); | ||||
| for (int i = 0; i < num_columns; i++) { | ||||
| expressions.emplace_back(compute::field_ref(FieldRef(i))); | ||||
| } | ||||
|
|
||||
| int i = 0; | ||||
| auto project_schema = input.output_schema; | ||||
| for (const auto& expr : project.expressions()) { | ||||
| expressions.emplace_back(); | ||||
| ARROW_ASSIGN_OR_RAISE(expressions.back(), | ||||
| std::shared_ptr<Field> project_field; | ||||
| ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr, | ||||
| FromProto(expr, ext_set, conversion_options)); | ||||
| auto bound_expr = des_expr.Bind(*input.output_schema); | ||||
| if (auto* expr_call = bound_expr->call()) { | ||||
| project_field = field(expr_call->function_name, | ||||
| expr_call->kernel->signature->out_type().type()); | ||||
| } else if (auto* field_ref = des_expr.field_ref()) { | ||||
| ARROW_ASSIGN_OR_RAISE(FieldPath field_path, | ||||
| field_ref->FindOne(*input.output_schema)); | ||||
| ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema)); | ||||
| } else if (auto* literal = des_expr.literal()) { | ||||
| project_field = | ||||
| field("field_" + std::to_string(num_columns + i), literal->type()); | ||||
| } | ||||
| ARROW_ASSIGN_OR_RAISE( | ||||
| project_schema, | ||||
| project_schema->AddField( | ||||
| num_columns + static_cast<int>(project.expressions().size()) - 1, | ||||
| std::move(project_field))); | ||||
| i++; | ||||
| expressions.emplace_back(des_expr); | ||||
| } | ||||
|
|
||||
| auto num_columns = static_cast<int>(expressions.size()); | ||||
| return DeclarationInfo{ | ||||
| DeclarationInfo project_declaration{ | ||||
| compute::Declaration::Sequence({ | ||||
| std::move(input.declaration), | ||||
| {"project", compute::ProjectNodeOptions{std::move(expressions)}}, | ||||
| }), | ||||
| num_columns}; | ||||
| project_schema}; | ||||
|
|
||||
| return ProcessEmit(std::move(project), std::move(project_declaration), | ||||
| std::move(project_schema)); | ||||
| } | ||||
|
|
||||
| case substrait::Rel::RelTypeCase::kJoin: { | ||||
|
|
@@ -355,15 +418,26 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| if (!left_keys || !right_keys) { | ||||
| return Status::Invalid("Left keys for join cannot be null"); | ||||
| } | ||||
|
|
||||
| // Create output schema from left, right relations and join keys | ||||
| FieldVector combined_fields = left.output_schema->fields(); | ||||
| const FieldVector& right_fields = right.output_schema->fields(); | ||||
| combined_fields.insert(combined_fields.end(), right_fields.begin(), | ||||
| right_fields.end()); | ||||
| std::shared_ptr<Schema> join_schema = schema(std::move(combined_fields)); | ||||
|
|
||||
| compute::HashJoinNodeOptions join_options{{std::move(*left_keys)}, | ||||
| {std::move(*right_keys)}}; | ||||
| join_options.join_type = join_type; | ||||
| join_options.key_cmp = {join_key_cmp}; | ||||
| compute::Declaration join_dec{"hashjoin", std::move(join_options)}; | ||||
| auto num_columns = left.num_columns + right.num_columns; | ||||
| join_dec.inputs.emplace_back(std::move(left.declaration)); | ||||
| join_dec.inputs.emplace_back(std::move(right.declaration)); | ||||
| return DeclarationInfo{std::move(join_dec), num_columns}; | ||||
|
|
||||
| DeclarationInfo join_declaration{std::move(join_dec), join_schema}; | ||||
|
|
||||
| return ProcessEmit(std::move(join), std::move(join_declaration), | ||||
| std::move(join_schema)); | ||||
| } | ||||
| case substrait::Rel::RelTypeCase::kAggregate: { | ||||
| const auto& aggregate = rel.aggregate(); | ||||
|
|
@@ -381,16 +455,25 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| "Grouping sets not supported. AggregateRel::groupings may not have more " | ||||
| "than one item"); | ||||
| } | ||||
|
|
||||
| // prepare output schema from aggregates | ||||
| auto input_schema = input.output_schema; | ||||
| // store key fields to be used when output schema is created | ||||
| std::vector<int> key_field_ids; | ||||
| std::vector<FieldRef> keys; | ||||
| if (aggregate.groupings_size() > 0) { | ||||
| const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); | ||||
| keys.reserve(group.grouping_expressions_size()); | ||||
| for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) { | ||||
| int grouping_expr_size = group.grouping_expressions_size(); | ||||
| keys.reserve(grouping_expr_size); | ||||
| key_field_ids.reserve(grouping_expr_size); | ||||
| for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) { | ||||
| ARROW_ASSIGN_OR_RAISE( | ||||
| compute::Expression expr, | ||||
| FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options)); | ||||
| const FieldRef* field_ref = expr.field_ref(); | ||||
| if (field_ref) { | ||||
| ARROW_ASSIGN_OR_RAISE(auto match, field_ref->FindOne(*input_schema)); | ||||
| key_field_ids.emplace_back(std::move(match[0])); | ||||
| keys.emplace_back(std::move(*field_ref)); | ||||
| } else { | ||||
| return Status::Invalid( | ||||
|
|
@@ -402,6 +485,8 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| int measure_size = aggregate.measures_size(); | ||||
| std::vector<compute::Aggregate> aggregates; | ||||
| aggregates.reserve(measure_size); | ||||
| // store aggregate fields to be used when output schema is created | ||||
| std::vector<int> agg_src_field_ids(measure_size); | ||||
| for (int measure_id = 0; measure_id < measure_size; measure_id++) { | ||||
| const auto& agg_measure = aggregate.measures(measure_id); | ||||
| if (agg_measure.has_measure()) { | ||||
|
|
@@ -416,17 +501,38 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet& | |||
| ExtensionIdRegistry::SubstraitAggregateToArrow converter, | ||||
| ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id())); | ||||
| ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); | ||||
|
|
||||
| // find aggregate field ids from schema | ||||
| const auto field_ref = arrow_agg.target; | ||||
| ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema)); | ||||
| agg_src_field_ids[measure_id] = match[0]; | ||||
|
|
||||
| aggregates.push_back(std::move(arrow_agg)); | ||||
| } else { | ||||
| return Status::Invalid("substrait::AggregateFunction not provided"); | ||||
| } | ||||
| } | ||||
| FieldVector output_fields; | ||||
| output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size()); | ||||
| // extract aggregate fields to output schema | ||||
| for (int id = 0; id < static_cast<int>(agg_src_field_ids.size()); id++) { | ||||
| output_fields.emplace_back(input_schema->field(agg_src_field_ids[id])); | ||||
| } | ||||
| // extract key fields to output schema | ||||
| for (int id = 0; id < static_cast<int>(key_field_ids.size()); id++) { | ||||
| output_fields.emplace_back(input_schema->field(key_field_ids[id])); | ||||
| } | ||||
|
||||
| // Aggregate fields come before key fields to match the behavior of GroupBy function |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on the comment this looks like intentional behavior of Arrow, so I don't think aggregate node is going to be adjusted to match Substrait. So that just means there should be a project node inserted behind the aggregate node that moves the columns around accordingly, right? I guess you could fix that in a separate JIRA/PR though. Maybe add a FIXME comment in that case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the requirement in Acero may be static here.
We can use the project to swap things around and document it properly. Probably we can do it in this PR as well.
cc @westonpace
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thoughts, it would be better to solve this one in another PR. Because I am not quite sure if this would break R test cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another PR is fine but I wouldn't consider the output order from Acero to be too static. Fixing it up to output things in the order Substrait expects would be nice so we can at least avoid the project node in some cases (when a direct emit). It'll be a breaking change and probably cause some slight heartburn to our existing tests but we should probably fix it while we still have the opportunity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Jira created here: https://issues.apache.org/jira/browse/ARROW-17656
Uh oh!
There was an error while loading. Please reload this page.