From 52725746470d80633c469ad097e2c933ff6aaf2f Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 24 Oct 2022 04:37:48 -0400 Subject: [PATCH 1/9] ARROW-17980: [C++] As-of-Join Substrait extension --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 34 +++ cpp/proto/substrait/extension_rels.proto | 44 ++++ .../arrow/compute/exec/asof_join_benchmark.cc | 12 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 107 ++++++++-- cpp/src/arrow/compute/exec/asof_join_node.h | 37 ++++ .../arrow/compute/exec/asof_join_node_test.cc | 27 ++- cpp/src/arrow/compute/exec/options.h | 41 ++-- cpp/src/arrow/engine/CMakeLists.txt | 1 + .../engine/substrait/expression_internal.cc | 2 + cpp/src/arrow/engine/substrait/options.cc | 113 ++++++++++ cpp/src/arrow/engine/substrait/options.h | 19 +- .../arrow/engine/substrait/plan_internal.cc | 2 + .../arrow/engine/substrait/plan_internal.h | 2 + .../engine/substrait/relation_internal.cc | 31 +++ .../engine/substrait/relation_internal.h | 8 + cpp/src/arrow/engine/substrait/serde.cc | 2 + cpp/src/arrow/engine/substrait/serde_test.cc | 200 ++++++++++++++++++ .../engine/substrait/test_plan_builder.cc | 2 + cpp/src/arrow/engine/substrait/type_fwd.h | 1 + .../arrow/engine/substrait/type_internal.cc | 83 ++++---- .../arrow/engine/substrait/type_internal.h | 2 + 21 files changed, 680 insertions(+), 90 deletions(-) create mode 100644 cpp/proto/substrait/extension_rels.proto create mode 100644 cpp/src/arrow/compute/exec/asof_join_node.h create mode 100644 cpp/src/arrow/engine/substrait/options.cc diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index d0c8c600d55..00e8a72166e 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -657,6 +657,13 @@ else() "${THIRDPARTY_MIRROR_URL}/snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz") endif() +# Remove these two lines once https://github.com/substrait-io/substrait/pull/342 merges +set(ENV{ARROW_SUBSTRAIT_URL} + "https://github.com/substrait-io/substrait/archive/e59008b6b202f8af06c2266991161b1e45cb056a.tar.gz" +) +set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM + "f64629cb377fcc62c9d3e8fe69fa6a4cf326f34d756e03db84843c5cce8d04cd") + if(DEFINED ENV{ARROW_SUBSTRAIT_URL}) set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}") else() @@ -1738,6 +1745,10 @@ macro(build_substrait) # Note: not all protos in Substrait actually matter to plan # consumption. No need to build the ones we don't need. set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type) + set(ARROW_SUBSTRAIT_PROTOS extension_rels) + set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto") + message("SOURCE DIR IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}" + ) externalproject_add(substrait_ep ${EP_COMMON_OPTIONS} @@ -1789,6 +1800,29 @@ macro(build_substrait) list(APPEND SUBSTRAIT_SOURCES "${SUBSTRAIT_PROTO_GEN}.cc") endforeach() + message("SOURCE DIR2 IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}" + ) + foreach(ARROW_SUBSTRAIT_PROTO ${ARROW_SUBSTRAIT_PROTOS}) + set(ARROW_SUBSTRAIT_PROTO_GEN + "${SUBSTRAIT_CPP_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.pb") + foreach(EXT h cc) + set_source_files_properties("${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}" + PROPERTIES COMPILE_OPTIONS + "${SUBSTRAIT_SUPPRESSED_FLAGS}" + GENERATED TRUE + SKIP_UNITY_BUILD_INCLUSION TRUE) + list(APPEND SUBSTRAIT_PROTO_GEN_ALL "${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}") + endforeach() + add_custom_command(OUTPUT "${ARROW_SUBSTRAIT_PROTO_GEN}.cc" + "${ARROW_SUBSTRAIT_PROTO_GEN}.h" + COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${SUBSTRAIT_LOCAL_DIR}/proto" + "-I${ARROW_SUBSTRAIT_PROTOS_DIR}" + "--cpp_out=${SUBSTRAIT_CPP_DIR}" + "${ARROW_SUBSTRAIT_PROTOS_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.proto" + DEPENDS ${PROTO_DEPENDS} substrait_ep) + + list(APPEND SUBSTRAIT_SOURCES "${ARROW_SUBSTRAIT_PROTO_GEN}.cc") + endforeach() add_custom_target(substrait_gen ALL DEPENDS ${SUBSTRAIT_PROTO_GEN_ALL}) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto new file mode 100644 index 00000000000..6f806d00e5b --- /dev/null +++ b/cpp/proto/substrait/extension_rels.proto @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +syntax = "proto3"; + +package arrow.substrait; + +import "substrait/algebra.proto"; + +option csharp_namespace = "Arrow.Substrait"; +option go_package = "github.com/apache/arrow/substrait"; +option java_multiple_files = true; +option java_package = "io.arrow.substrait"; + +// As-Of-Join relation +message AsOfJoinRel { + // One key per input relation, each key describing how to join the corresponding input + repeated AsOfJoinKey keys = 1; + + // As-Of tolerance, in units of the on-key + int64 tolerance = 2; + + // As-Of-Join key + message AsOfJoinKey { + // A field reference defining the on-key + .substrait.Expression on = 1; + + // A set of field references defining the by-key + repeated .substrait.Expression by = 2; + } +} diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index 5890e10c206..366508e34b8 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -101,9 +101,19 @@ static void TableJoinOverhead(benchmark::State& state, benchmark::Counter(static_cast(default_memory_pool()->max_memory())); } +AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key, + std::vector by_key, int64_t tolerance) { + std::vector input_keys(repeat); + for (size_t i = 0; i < repeat; i++) { + input_keys[i] = {on_key, by_key}; + } + return AsofJoinNodeOptions(input_keys, tolerance); +} + static void AsOfJoinOverhead(benchmark::State& state) { int64_t tolerance = 0; - AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance); + AsofJoinNodeOptions options = + GetRepeatedOptions(int(state.range(4)), kTimeCol, {kKeyCol}, tolerance); TableJoinOverhead( state, TableGenerationProperties{int(state.range(0)), int(state.range(1)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 83bbf5df4ca..d071c0ce7f4 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/compute/exec/asof_join_node.h" + #include #include #include @@ -997,17 +999,16 @@ class AsofJoinNode : public ExecNode { } static arrow::Result> MakeOutputSchema( - const std::vector& inputs, + const std::vector> input_schema, const std::vector& indices_of_on_key, const std::vector>& indices_of_by_key) { std::vector> fields; - size_t n_by = indices_of_by_key[0].size(); + size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size(); const DataType* on_key_type = NULLPTR; std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields - for (size_t j = 0; j < inputs.size(); ++j) { - const auto& input_schema = inputs[j]->output_schema(); + for (size_t j = 0; j < input_schema.size(); ++j) { const auto& on_field_ix = indices_of_on_key[j]; const auto& by_field_ix = indices_of_by_key[j]; @@ -1015,10 +1016,10 @@ class AsofJoinNode : public ExecNode { return Status::Invalid("Missing join key on table ", j); } - const auto& on_field = input_schema->fields()[on_field_ix]; + const auto& on_field = input_schema[j]->fields()[on_field_ix]; std::vector by_field(n_by); for (size_t k = 0; k < n_by; k++) { - by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); } if (on_key_type == NULLPTR) { @@ -1038,8 +1039,8 @@ class AsofJoinNode : public ExecNode { } } - for (int i = 0; i < input_schema->num_fields(); ++i) { - const auto field = input_schema->field(i); + for (int i = 0; i < input_schema[j]->num_fields(); ++i) { + const auto field = input_schema[j]->field(i); if (i == on_field_ix) { ARROW_RETURN_NOT_OK(is_valid_on_field(field)); // Only add on field from the left table @@ -1076,6 +1077,56 @@ class AsofJoinNode : public ExecNode { return match.indices()[0]; } + static Result GetByKeySize( + const std::vector& input_keys) { + size_t n_by = 0; + for (size_t i = 0; i < input_keys.size(); ++i) { + const auto& by_key = input_keys[i].by_key; + if (i == 0) { + n_by = by_key.size(); + } else if (n_by != by_key.size()) { + return Status::Invalid("inconsistent size of by-key across inputs"); + } + } + return n_by; + } + + static Result> GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + size_t n_input = input_schema.size(); + std::vector indices_of_on_key(n_input); + for (size_t i = 0; i < n_input; ++i) { + const auto& on_key = input_keys[i].on_key; + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(*input_schema[i], on_key, "on")); + } + return indices_of_on_key; + } + + static Result>> GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); + size_t n_input = input_schema.size(); + std::vector> indices_of_by_key( + n_input, std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + for (size_t k = 0; k < n_by; k++) { + const auto& by_key = input_keys[i].by_key; + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(*input_schema[i], by_key[k], "by")); + } + } + return indices_of_by_key; + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; @@ -1086,24 +1137,21 @@ class AsofJoinNode : public ExecNode { join_options.tolerance); } - size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); + size_t n_input = inputs.size(); std::vector input_labels(n_input); - std::vector indices_of_on_key(n_input); - std::vector> indices_of_by_key( - n_input, std::vector(n_by)); + std::vector> input_schema(n_input); for (size_t i = 0; i < n_input; ++i) { input_labels[i] = i == 0 ? "left" : "right_" + ToChars(i); - const Schema& input_schema = *inputs[i]->output_schema(); - ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], - FindColIndex(input_schema, join_options.on_key, "on")); - for (size_t k = 0; k < n_by; k++) { - ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], - FindColIndex(input_schema, join_options.by_key[k], "by")); - } + input_schema[i] = inputs[i]->output_schema(); } - - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + GetIndicesOfOnKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + GetIndicesOfByKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr output_schema, + MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key)); std::vector> key_hashers; for (size_t i = 0; i < n_input; i++) { @@ -1213,5 +1261,20 @@ void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { } } // namespace internal +namespace asofjoin { + +Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys) { + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + AsofJoinNode::GetIndicesOfOnKey(input_schema, input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + AsofJoinNode::GetIndicesOfByKey(input_schema, input_keys)); + return AsofJoinNode::MakeOutputSchema(input_schema, indices_of_on_key, + indices_of_by_key); +} + +} // namespace asofjoin + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node.h b/cpp/src/arrow/compute/exec/asof_join_node.h new file mode 100644 index 00000000000..27777090d3d --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/options.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { +namespace asofjoin { + +using AsofJoinKeys = AsofJoinNodeOptions::Keys; + +ARROW_EXPORT Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys); + +} // namespace asofjoin +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index c865f9f38f8..e30e8420956 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -142,6 +142,15 @@ void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { ASSERT_OK(builder.Finish(&empty)); } +AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key, + std::vector by_key, int64_t tolerance) { + std::vector input_keys(repeat); + for (size_t i = 0; i < repeat; i++) { + input_keys[i] = {on_key, by_key}; + } + return AsofJoinNodeOptions(input_keys, tolerance); +} + // mutates by copying from_key into to_key and changing from_key to zero Result MutateByKey(BatchesWithSchema& batches, std::string from_key, std::string to_key, bool replace_key = false, @@ -248,7 +257,7 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ const FieldRef time, by_key_type key, const int64_t tolerance) { \ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ - AsofJoinNodeOptions(time, {key}, tolerance)); \ + GetRepeatedOptions(3, time, {key}, tolerance)); \ } EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) @@ -300,7 +309,7 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, int64_t tolerance, const std::string& expected_error_str) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {"key"}, tolerance), + GetRepeatedOptions(2, "time", {"key"}, tolerance), expected_error_str); } @@ -323,27 +332,27 @@ void DoRunMissingKeysTest(const std::shared_ptr& l_schema, void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("invalid_time", {"key"}, 0), + GetRepeatedOptions(2, "invalid_time", {"key"}, 0), "Bad join key on table : No match"); } void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {"invalid_key"}, 0), + GetRepeatedOptions(2, "time", {"invalid_key"}, 0), "Bad join key on table : No match"); } void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0), + DoRunInvalidPlanTest(l_schema, r_schema, GetRepeatedOptions(2, {0, "time"}, {"key"}, 0), "Bad join key on table : No match"); } void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0), + GetRepeatedOptions(2, "time", {FieldRef{0, 1}}, 0), "Bad join key on table : No match"); } @@ -404,7 +413,7 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema, - AsofJoinNodeOptions("time", {"key"}, 1000), + GetRepeatedOptions(2, "time", {"key"}, 1000), "out-of-order on-key values"); } @@ -501,7 +510,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(exp_nokey_batches, MutateByKey(exp_nokey_batches, "key", "key2", true, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, - AsofJoinNodeOptions("time", {"key2"}, tolerance)); + GetRepeatedOptions(3, "time", {"key2"}, tolerance)); }); } static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } @@ -514,7 +523,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key", false, false, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches, - AsofJoinNodeOptions("time", {}, tolerance)); + GetRepeatedOptions(3, "time", {}, tolerance)); }); } static void DoMutateEmptyKey(BasicTest& basic_tests) { diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 5daaf0584ae..bad5d5669a9 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -479,22 +479,35 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} - - /// \brief "on" key for the join. + /// \brief Keys for one input table of the AsofJoin operation /// - /// All inputs tables must be sorted by the "on" key. Must be a single field of a common - /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff - /// left_on - tolerance <= right_on <= left_on. - /// Currently, the "on" key must be of an integer, date, or timestamp type. - FieldRef on_key; - /// \brief "by" key for the join. + /// The keys must be consistent across the input tables: + /// Each "on" key must refer to a field of the same type and units across the tables. + /// Each "by" key must refer to a list of fields of the same types across the tables. + struct Keys { + /// \brief "on" key for the join. + /// + /// The input table must be sorted by the "on" key. Must be a single field of a common + /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff + /// left_on - tolerance <= right_on <= left_on. + /// Currently, the "on" key must be of an integer, date, or timestamp type. + FieldRef on_key; + /// \brief "by" key for the join. + /// + /// The input table must have each field of the "by" key. Exact equality is used for + /// each field of the "by" key. + /// Currently, each field of the "by" key must be of an integer, date, timestamp, or + /// base-binary type. + std::vector by_key; + }; + + AsofJoinNodeOptions(std::vector input_keys, int64_t tolerance) + : input_keys(std::move(input_keys)), tolerance(tolerance) {} + + /// \brief AsofJoin keys per input table. /// - /// All input tables must have the "by" key. Exact equality - /// is used for the "by" key. - /// Currently, the "by" key must be of an integer, date, timestamp, or base-binary type - std::vector by_key; + /// See `Keys` for details. + std::vector input_keys; /// \brief Tolerance for inexact "on" key matching. Must be non-negative. /// /// The tolerance is interpreted in the same units as the "on" key. diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index f129394325e..4e5f8bb96b7 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -23,6 +23,7 @@ set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc substrait/extension_set.cc substrait/extension_types.cc + substrait/options.cc substrait/plan_internal.cc substrait/relation_internal.cc substrait/serde.cc diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 415cd195bf6..23c2d5bc19f 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -68,6 +68,8 @@ using internal::ToChars; namespace engine { +namespace substrait = ::substrait; + namespace { Id NormalizeFunctionName(Id id) { diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc new file mode 100644 index 00000000000..f7c8bf4713e --- /dev/null +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include + +#include "arrow/engine/substrait/options.h" + +#include +#include "arrow/compute/exec/asof_join_node.h" +#include "arrow/compute/exec/options.h" +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/relation_internal.h" +#include "substrait/extension_rels.pb.h" + +namespace arrow { +namespace engine { + +namespace substrait = ::substrait; + +class DefaultExtensionProvider : public ExtensionProvider { + public: + Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) override { + if (rel.Is()) { + arrow::substrait::AsOfJoinRel as_of_join_rel; + rel.UnpackTo(&as_of_join_rel); + return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set); + } + return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ", + rel.DebugString()); + } + + private: + Result MakeAsOfJoinRel( + const std::vector& inputs, + const arrow::substrait::AsOfJoinRel& as_of_join_rel, const ExtensionSet& ext_set) { + if (inputs.size() < 2) { + return Status::Invalid("substrait::AsOfJoinNode too few input tables: ", + inputs.size()); + } + if (static_cast(as_of_join_rel.keys_size()) != inputs.size()) { + return Status::Invalid("substrait::AsOfJoinNode mismatched number of inputs"); + } + + size_t n_input = inputs.size(), i = 0; + std::vector input_keys(n_input); + for (const auto& keys : as_of_join_rel.keys()) { + // on-key + if (!keys.has_on()) { + return Status::Invalid("substrait::AsOfJoinNode missing on-key for input ", i); + } + ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(keys.on(), ext_set, {})); + if (on_key_expr.field_ref() == NULLPTR) { + return Status::NotImplemented( + "substrait::AsOfJoinNode non-field-ref on-key for input ", i); + } + const FieldRef& on_key = *on_key_expr.field_ref(); + + // by-key + std::vector by_key; + for (const auto& by_item : keys.by()) { + ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {})); + if (by_key_expr.field_ref() == NULLPTR) { + return Status::NotImplemented( + "substrait::AsOfJoinNode non-field-ref by-key for input ", i); + } + by_key.push_back(*by_key_expr.field_ref()); + } + + input_keys[i] = {std::move(on_key), std::move(by_key)}; + ++i; + } + + // schema + int64_t tolerance = as_of_join_rel.tolerance(); + std::vector> input_schema(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_schema[i] = inputs[i].output_schema; + } + ARROW_ASSIGN_OR_RAISE(auto schema, + compute::asofjoin::MakeOutputSchema(input_schema, input_keys)); + compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance}; + + // declaration + std::vector input_decls(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_decls[i] = inputs[i].declaration; + } + return DeclarationInfo{ + compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)), + std::move(schema)}; + } +}; + +std::shared_ptr ExtensionProvider::kDefaultExtensionProvider = + std::make_shared(); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 014842f4d8f..57f29f65630 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -23,7 +23,11 @@ #include #include +#include + #include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" namespace arrow { @@ -32,7 +36,7 @@ namespace engine { /// How strictly to adhere to the input structure when converting between Substrait and /// Acero representations of a plan. This allows the user to trade conversion accuracy /// for performance and lenience. -enum class ConversionStrictness { +enum class ARROW_ENGINE_EXPORT ConversionStrictness { /// When a primitive is used at the input that doesn't have an exact match at the /// output, reject the conversion. This effectively asserts that there is no (known) /// information loss in the conversion, and that plans should either round-trip back and @@ -65,9 +69,18 @@ using NamedTableProvider = std::function(const std::vector&)>; static NamedTableProvider kDefaultNamedTableProvider; +class ARROW_ENGINE_EXPORT ExtensionProvider { + public: + static std::shared_ptr kDefaultExtensionProvider; + virtual ~ExtensionProvider() = default; + virtual Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) = 0; +}; + /// Options that control the conversion between Substrait and Acero representations of a /// plan. -struct ConversionOptions { +struct ARROW_ENGINE_EXPORT ConversionOptions { /// \brief How strictly the converter should adhere to the structure of the input. ConversionStrictness strictness = ConversionStrictness::BEST_EFFORT; /// \brief A custom strategy to be used for providing named tables @@ -75,6 +88,8 @@ struct ConversionOptions { /// The default behavior will return an invalid status if the plan has any /// named table relations. NamedTableProvider named_table_provider = kDefaultNamedTableProvider; + std::shared_ptr extension_provider = + ExtensionProvider::kDefaultExtensionProvider; }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index fff0f7563cc..0ab018e378e 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -42,6 +42,8 @@ using internal::checked_cast; namespace engine { +namespace substrait = ::substrait; + Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) { plan->clear_extension_uris(); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 235bf1a6ce1..c2094ae1e61 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -33,6 +33,8 @@ namespace arrow { namespace engine { +namespace substrait = ::substrait; + /// \brief Replaces the extension information of a Substrait Plan message with the given /// extension set, such that the anchors defined therein can be used in the rest of the /// plan. diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 744227a9339..391928dd73e 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -68,6 +68,8 @@ using internal::UriFromAbsolutePath; namespace engine { +namespace substrait = ::substrait; + struct EmitInfo { std::vector expressions; std::shared_ptr schema; @@ -685,6 +687,35 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return ProcessEmit(std::move(set), std::move(set_declaration), std::move(union_schema)); } + case substrait::Rel::RelTypeCase::kExtensionLeaf: { + const auto& ext = rel.extension_leaf(); + ARROW_ASSIGN_OR_RAISE( + auto ext_leaf_decl, + conversion_options.extension_provider->MakeRel({}, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_leaf_decl), ext_leaf_decl.output_schema); + } + case substrait::Rel::RelTypeCase::kExtensionSingle: { + const auto& ext = rel.extension_single(); + ARROW_ASSIGN_OR_RAISE(DeclarationInfo input, + FromProto(ext.input(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + auto ext_single_decl, + conversion_options.extension_provider->MakeRel({input}, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_single_decl), ext_single_decl.output_schema); + } + case substrait::Rel::RelTypeCase::kExtensionMulti: { + const auto& ext = rel.extension_multi(); + std::vector inputs; + for (const auto& input : ext.inputs()) { + ARROW_ASSIGN_OR_RAISE(auto input_info, + FromProto(input, ext_set, conversion_options)); + inputs.push_back(std::move(input_info)); + } + ARROW_ASSIGN_OR_RAISE( + auto ext_multi_decl, + conversion_options.extension_provider->MakeRel(inputs, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_multi_decl), ext_multi_decl.output_schema); + } default: break; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 17153f5365f..ee2848122d3 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -32,6 +32,14 @@ namespace arrow { namespace engine { +/// Information resulting from converting a Substrait relation. +struct ARROW_ENGINE_EXPORT DeclarationInfo { + /// The compute declaration produced thus far. + compute::Declaration declaration; + + std::shared_ptr output_schema; +}; + /// \brief Convert a Substrait Rel object to an Acero declaration ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index ac5de90326e..8290f14caf3 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -47,6 +47,8 @@ namespace arrow { namespace engine { +namespace substrait = ::substrait; + Status ParseFromBufferImpl(const Buffer& buf, const std::string& full_name, google::protobuf::Message* message) { google::protobuf::io::ArrayInputStream buf_stream{buf.data(), diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 435006a4e03..d302d753796 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -32,6 +32,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" +#include "arrow/compute/exec/asof_join_node.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/expression_internal.h" @@ -3351,6 +3352,16 @@ TEST(Substrait, IsthmusPlan) { /*include_columns=*/{}, conversion_options); } +NamedTableProvider ProvideMadeTable( + std::function>(const std::vector&)> make) { + return [make](const std::vector& names) -> Result { + ARROW_ASSIGN_OR_RAISE(auto table, make(names)); + std::shared_ptr options = + std::make_shared(table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; +} + TEST(Substrait, ProjectWithMultiFieldExpressions) { auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -4000,5 +4011,194 @@ TEST(Substrait, SetRelationBasic) { &sort_options); } +TEST(Substrait, PlanWithAsOfJoinExtension) { + // This demos an extension relation + std::string substrait_json = R"({ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "extension_multi": { + "common": { + "emit": { + "outputMapping": [0, 1, 2, 3] + } + }, + "inputs": [ + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value1"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T1"] + } + } + }, + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value2"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T2"] + } + } + } + ], + "detail": { + "@type": "/arrow.substrait.AsOfJoinRel", + "keys" : [ + { + "on": { + "selection": { + "directReference": { + "structField": { + "field": 0, + } + }, + "rootReference": {} + } + }, + "by": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1, + } + }, + "rootReference": {} + } + } + ] + }, + { + "on": { + "selection": { + "directReference": { + "structField": { + "field": 0, + } + }, + "rootReference": {} + } + }, + "by": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1, + } + }, + "rootReference": {} + } + } + ] + } + ], + "tolerance": 1000 + } + } + }, + "names": ["time", "key", "value1", "value2"] + } + }], + "expectedTypeUrls": [] + })"; + + std::vector> input_schema = { + schema({field("time", int32()), field("key", int32()), field("value1", float64())}), + schema( + {field("time", int32()), field("key", int32()), field("value2", float64())})}; + NamedTableProvider table_provider = ProvideMadeTable( + [&input_schema]( + const std::vector& names) -> Result> { + if (names.size() != 1) { + return Status::Invalid("Multiple test table names"); + } + if (names[0] == "T1") { + return TableFromJSON(input_schema[0], + {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"}); + } + if (names[0] == "T2") { + return TableFromJSON(input_schema[1], + {"[[1, 1, 1.2], [3, 2, 2.2], [5, 2, 3.2]]"}); + } + return Status::Invalid("Unknown test table name ", names[0]); + }); + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + ASSERT_OK_AND_ASSIGN( + auto out_schema, + compute::asofjoin::MakeOutputSchema( + input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); + auto expected_table = TableFromJSON( + out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); + CheckRoundTripResult(std::move(out_schema), std::move(expected_table), + *compute::default_exec_context(), buf, {}, conversion_options); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc index 62f4361a610..d2dbe77e8d6 100644 --- a/cpp/src/arrow/engine/substrait/test_plan_builder.cc +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc @@ -36,6 +36,8 @@ namespace arrow { namespace engine { + +namespace substrait = ::substrait; namespace internal { static const ConversionOptions kPlanBuilderConversionOptions; diff --git a/cpp/src/arrow/engine/substrait/type_fwd.h b/cpp/src/arrow/engine/substrait/type_fwd.h index 235d9e82d1b..6089d3f747a 100644 --- a/cpp/src/arrow/engine/substrait/type_fwd.h +++ b/cpp/src/arrow/engine/substrait/type_fwd.h @@ -26,6 +26,7 @@ class ExtensionIdRegistry; class ExtensionSet; struct ConversionOptions; +struct DeclarationInfo; } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index f56aa19a040..01f189d224c 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -44,7 +44,7 @@ namespace { template bool IsNullable(const TypeMessage& type) { // FIXME what can we do with NULLABILITY_UNSPECIFIED - return type.nullability() != ::substrait::Type::NULLABILITY_REQUIRED; + return type.nullability() != substrait::Type::NULLABILITY_REQUIRED; } template @@ -95,67 +95,67 @@ Result FieldsFromProto(int size, const Types& types, } // namespace Result, bool>> FromProto( - const ::substrait::Type& type, const ExtensionSet& ext_set, + const substrait::Type& type, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { switch (type.kind_case()) { - case ::substrait::Type::kBool: + case substrait::Type::kBool: return FromProtoImpl(type.bool_()); - case ::substrait::Type::kI8: + case substrait::Type::kI8: return FromProtoImpl(type.i8()); - case ::substrait::Type::kI16: + case substrait::Type::kI16: return FromProtoImpl(type.i16()); - case ::substrait::Type::kI32: + case substrait::Type::kI32: return FromProtoImpl(type.i32()); - case ::substrait::Type::kI64: + case substrait::Type::kI64: return FromProtoImpl(type.i64()); - case ::substrait::Type::kFp32: + case substrait::Type::kFp32: return FromProtoImpl(type.fp32()); - case ::substrait::Type::kFp64: + case substrait::Type::kFp64: return FromProtoImpl(type.fp64()); - case ::substrait::Type::kString: + case substrait::Type::kString: return FromProtoImpl(type.string()); - case ::substrait::Type::kBinary: + case substrait::Type::kBinary: return FromProtoImpl(type.binary()); - case ::substrait::Type::kTimestamp: + case substrait::Type::kTimestamp: return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); - case ::substrait::Type::kTimestampTz: + case substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, TimestampTzTimezoneString()); - case ::substrait::Type::kDate: + case substrait::Type::kDate: return FromProtoImpl(type.date()); - case ::substrait::Type::kTime: + case substrait::Type::kTime: return FromProtoImpl(type.time(), TimeUnit::MICRO); - case ::substrait::Type::kIntervalYear: + case substrait::Type::kIntervalYear: return FromProtoImpl(type.interval_year(), interval_year); - case ::substrait::Type::kIntervalDay: + case substrait::Type::kIntervalDay: return FromProtoImpl(type.interval_day(), interval_day); - case ::substrait::Type::kUuid: + case substrait::Type::kUuid: return FromProtoImpl(type.uuid(), uuid); - case ::substrait::Type::kFixedChar: + case substrait::Type::kFixedChar: return FromProtoImpl(type.fixed_char(), fixed_char, type.fixed_char().length()); - case ::substrait::Type::kVarchar: + case substrait::Type::kVarchar: return FromProtoImpl(type.varchar(), varchar, type.varchar().length()); - case ::substrait::Type::kFixedBinary: + case substrait::Type::kFixedBinary: return FromProtoImpl(type.fixed_binary(), type.fixed_binary().length()); - case ::substrait::Type::kDecimal: { + case substrait::Type::kDecimal: { const auto& decimal = type.decimal(); return FromProtoImpl(decimal, decimal.precision(), decimal.scale()); } - case ::substrait::Type::kStruct: { + case substrait::Type::kStruct: { const auto& struct_ = type.struct_(); ARROW_ASSIGN_OR_RAISE( @@ -166,7 +166,7 @@ Result, bool>> FromProto( return FromProtoImpl(struct_, std::move(fields)); } - case ::substrait::Type::kList: { + case substrait::Type::kList: { const auto& list = type.list(); if (!list.has_type()) { @@ -181,7 +181,7 @@ Result, bool>> FromProto( list, field("item", std::move(type_nullable.first), type_nullable.second)); } - case ::substrait::Type::kMap: { + case substrait::Type::kMap: { const auto& map = type.map(); static const std::array kMissing = {"key and value", "value", "key", @@ -207,7 +207,7 @@ Result, bool>> FromProto( field("value", std::move(value_nullable.first), value_nullable.second)); } - case ::substrait::Type::kUserDefined: { + case substrait::Type::kUserDefined: { const auto& user_defined = type.user_defined(); uint32_t anchor = user_defined.type_reference(); ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); @@ -321,8 +321,7 @@ struct DataTypeToProtoImpl { for (const auto& field : t.fields()) { if (field->metadata() != nullptr) { - return Status::Invalid( - "::substrait::Type::Struct does not support field metadata"); + return Status::Invalid("substrait::Type::Struct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set_, conversion_options_)); @@ -386,8 +385,8 @@ struct DataTypeToProtoImpl { template Sub* SetWithThen(void (::substrait::Type::*set_allocated_sub)(Sub*)) { auto sub = std::make_unique(); - sub->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE - : ::substrait::Type::NULLABILITY_REQUIRED); + sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); auto out = sub.get(); (type_->*set_allocated_sub)(sub.release()); @@ -402,37 +401,37 @@ struct DataTypeToProtoImpl { template Status EncodeUserDefined(const T& t) { ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(t)); - auto user_defined = std::make_unique<::substrait::Type::UserDefined>(); + auto user_defined = std::make_unique(); user_defined->set_type_reference(anchor); - user_defined->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE - : ::substrait::Type::NULLABILITY_REQUIRED); + user_defined->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); type_->set_allocated_user_defined(user_defined.release()); return Status::OK(); } Status NotImplemented(const DataType& t) { - return Status::NotImplemented("conversion to ::substrait::Type from ", t.ToString()); + return Status::NotImplemented("conversion to substrait::Type from ", t.ToString()); } Status operator()(const DataType& type) { return VisitTypeInline(type, this); } - ::substrait::Type* type_; + substrait::Type* type_; bool nullable_; ExtensionSet* ext_set_; const ConversionOptions& conversion_options_; }; } // namespace -Result> ToProto( +Result> ToProto( const DataType& type, bool nullable, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto out = std::make_unique<::substrait::Type>(); + auto out = std::make_unique(); RETURN_NOT_OK( (DataTypeToProtoImpl{out.get(), nullable, ext_set, conversion_options})(type)); return std::move(out); } -Result> FromProto(const ::substrait::NamedStruct& named_struct, +Result> FromProto(const substrait::NamedStruct& named_struct, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (!named_struct.has_struct_()) { @@ -476,7 +475,7 @@ void ToProtoGetDepthFirstNames(const FieldVector& fields, } } // namespace -Result> ToProto( +Result> ToProto( const Schema& schema, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP && @@ -484,20 +483,20 @@ Result> ToProto( return Status::Invalid("::substrait::NamedStruct does not support schema metadata"); } - auto named_struct = std::make_unique<::substrait::NamedStruct>(); + auto named_struct = std::make_unique(); auto names = named_struct->mutable_names(); names->Reserve(schema.num_fields()); ToProtoGetDepthFirstNames(schema.fields(), names); - auto struct_ = std::make_unique<::substrait::Type::Struct>(); + auto struct_ = std::make_unique(); auto types = struct_->mutable_types(); types->Reserve(schema.num_fields()); for (const auto& field : schema.fields()) { if (conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP && field->metadata() != nullptr) { - return Status::Invalid("::substrait::NamedStruct does not support field metadata"); + return Status::Invalid("substrait::NamedStruct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set, diff --git a/cpp/src/arrow/engine/substrait/type_internal.h b/cpp/src/arrow/engine/substrait/type_internal.h index 0d53028f493..b162e4dc2b2 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.h +++ b/cpp/src/arrow/engine/substrait/type_internal.h @@ -33,6 +33,8 @@ namespace arrow { namespace engine { +namespace substrait = ::substrait; + ARROW_ENGINE_EXPORT Result, bool>> FromProto(const substrait::Type&, const ExtensionSet&, From ad04595494c199183c6f9ff4d06ef2b2915a3269 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 4 Nov 2022 08:05:40 -0400 Subject: [PATCH 2/9] remove options.h dependency on protobuf --- cpp/src/arrow/engine/substrait/options.cc | 5 +++ cpp/src/arrow/engine/substrait/options.h | 16 ++----- .../arrow/engine/substrait/options_internal.h | 44 +++++++++++++++++++ .../engine/substrait/relation_internal.cc | 1 + 4 files changed, 54 insertions(+), 12 deletions(-) create mode 100644 cpp/src/arrow/engine/substrait/options_internal.h diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index f7c8bf4713e..0fe7527ee83 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -22,6 +22,7 @@ #include "arrow/compute/exec/asof_join_node.h" #include "arrow/compute/exec/options.h" #include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/options_internal.h" #include "arrow/engine/substrait/relation_internal.h" #include "substrait/extension_rels.pb.h" @@ -109,5 +110,9 @@ class DefaultExtensionProvider : public ExtensionProvider { std::shared_ptr ExtensionProvider::kDefaultExtensionProvider = std::make_shared(); +std::shared_ptr default_extension_provider() { + return ExtensionProvider::kDefaultExtensionProvider; +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 57f29f65630..41d0792a8d2 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -23,8 +23,6 @@ #include #include -#include - #include "arrow/compute/type_fwd.h" #include "arrow/engine/substrait/type_fwd.h" #include "arrow/engine/substrait/visibility.h" @@ -69,14 +67,9 @@ using NamedTableProvider = std::function(const std::vector&)>; static NamedTableProvider kDefaultNamedTableProvider; -class ARROW_ENGINE_EXPORT ExtensionProvider { - public: - static std::shared_ptr kDefaultExtensionProvider; - virtual ~ExtensionProvider() = default; - virtual Result MakeRel(const std::vector& inputs, - const google::protobuf::Any& rel, - const ExtensionSet& ext_set) = 0; -}; +class ExtensionProvider; + +ARROW_ENGINE_EXPORT std::shared_ptr default_extension_provider(); /// Options that control the conversion between Substrait and Acero representations of a /// plan. @@ -88,8 +81,7 @@ struct ARROW_ENGINE_EXPORT ConversionOptions { /// The default behavior will return an invalid status if the plan has any /// named table relations. NamedTableProvider named_table_provider = kDefaultNamedTableProvider; - std::shared_ptr extension_provider = - ExtensionProvider::kDefaultExtensionProvider; + std::shared_ptr extension_provider = default_extension_provider(); }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/options_internal.h b/cpp/src/arrow/engine/substrait/options_internal.h new file mode 100644 index 00000000000..0d186147a9a --- /dev/null +++ b/cpp/src/arrow/engine/substrait/options_internal.h @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/visibility.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace engine { + +class ARROW_ENGINE_EXPORT ExtensionProvider { + public: + static std::shared_ptr kDefaultExtensionProvider; + virtual ~ExtensionProvider() = default; + virtual Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) = 0; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 391928dd73e..5e48000643e 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -44,6 +44,7 @@ #include "arrow/engine/substrait/expression_internal.h" #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/options.h" +#include "arrow/engine/substrait/options_internal.h" #include "arrow/engine/substrait/relation.h" #include "arrow/engine/substrait/type_internal.h" #include "arrow/engine/substrait/util.h" From 05d6aa6069de0254db3f66699293a3406adc5105 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 6 Nov 2022 08:30:02 -0500 Subject: [PATCH 3/9] requested changes --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 4 ---- cpp/src/arrow/engine/substrait/plan_internal.cc | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 00e8a72166e..6a668be20a0 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1747,8 +1747,6 @@ macro(build_substrait) set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type) set(ARROW_SUBSTRAIT_PROTOS extension_rels) set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto") - message("SOURCE DIR IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}" - ) externalproject_add(substrait_ep ${EP_COMMON_OPTIONS} @@ -1800,8 +1798,6 @@ macro(build_substrait) list(APPEND SUBSTRAIT_SOURCES "${SUBSTRAIT_PROTO_GEN}.cc") endforeach() - message("SOURCE DIR2 IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}" - ) foreach(ARROW_SUBSTRAIT_PROTO ${ARROW_SUBSTRAIT_PROTOS}) set(ARROW_SUBSTRAIT_PROTO_GEN "${SUBSTRAIT_CPP_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.pb") diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 0ab018e378e..f787a0206e9 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -145,7 +145,7 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, namespace { -// FIXME Is there some way to get these from the cmake files? +// TODO(ARROW-18145) Populate these from cmake files constexpr uint32_t kSubstraitMajorVersion = 0; constexpr uint32_t kSubstraitMinorVersion = 20; constexpr uint32_t kSubstraitPatchVersion = 0; From 498db94458eabb9669edbe682d81001746d725b5 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 27 Nov 2022 05:01:02 -0500 Subject: [PATCH 4/9] requested fixes --- cpp/proto/substrait/extension_rels.proto | 2 +- cpp/src/arrow/compute/exec/options.h | 6 +- .../engine/substrait/expression_internal.cc | 193 +++++++++--------- .../engine/substrait/expression_internal.h | 19 +- cpp/src/arrow/engine/substrait/options.cc | 20 +- .../arrow/engine/substrait/plan_internal.cc | 28 ++- .../arrow/engine/substrait/plan_internal.h | 10 +- .../engine/substrait/relation_internal.cc | 124 +++++------ .../engine/substrait/relation_internal.h | 4 +- cpp/src/arrow/engine/substrait/serde.cc | 23 +-- cpp/src/arrow/engine/substrait/serde_test.cc | 2 +- .../engine/substrait/test_plan_builder.cc | 101 +++++---- .../arrow/engine/substrait/type_internal.h | 15 +- 13 files changed, 273 insertions(+), 274 deletions(-) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index 6f806d00e5b..ceed9f3e455 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -16,7 +16,7 @@ // under the License. syntax = "proto3"; -package arrow.substrait; +package arrow.substrait_ext; import "substrait/algebra.proto"; diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index bad5d5669a9..325e8e514d1 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -488,13 +488,13 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { /// \brief "on" key for the join. /// /// The input table must be sorted by the "on" key. Must be a single field of a common - /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff + /// type. Inexact match is used on the "on" key. i.e., a row is considered a match iff /// left_on - tolerance <= right_on <= left_on. /// Currently, the "on" key must be of an integer, date, or timestamp type. FieldRef on_key; /// \brief "by" key for the join. /// - /// The input table must have each field of the "by" key. Exact equality is used for + /// Each input table must have each field of the "by" key. Exact equality is used for /// each field of the "by" key. /// Currently, each field of the "by" key must be of an integer, date, timestamp, or /// base-binary type. @@ -506,7 +506,7 @@ class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { /// \brief AsofJoin keys per input table. /// - /// See `Keys` for details. + /// \see `Keys` for details. std::vector input_keys; /// \brief Tolerance for inexact "on" key matching. Must be non-negative. /// diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 23c2d5bc19f..38aa7e799ca 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -68,8 +68,6 @@ using internal::ToChars; namespace engine { -namespace substrait = ::substrait; - namespace { Id NormalizeFunctionName(Id id) { @@ -85,7 +83,7 @@ Id NormalizeFunctionName(Id id) { } // namespace -Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call, +Status DecodeArg(const ::substrait::FunctionArgument& arg, int idx, SubstraitCall* call, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (!arg.enum_().empty()) { @@ -102,7 +100,7 @@ Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* return Status::OK(); } -Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) { +Status DecodeOption(const ::substrait::FunctionOption& opt, SubstraitCall* call) { std::vector prefs; if (opt.preference_size() == 0) { return Status::Invalid("Invalid Substrait plan. The option ", opt.name(), @@ -116,7 +114,7 @@ Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) { } Result DecodeScalarFunction( - Id id, const substrait::Expression::ScalarFunction& scalar_fn, + Id id, const ::substrait::Expression::ScalarFunction& scalar_fn, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable, FromProto(scalar_fn.output_type(), ext_set, conversion_options)); @@ -131,22 +129,32 @@ Result DecodeScalarFunction( return std::move(call); } -Result FromProto(const substrait::AggregateFunction& func, bool is_hash, +std::string EnumToString(int value, const google::protobuf::EnumDescriptor* descriptor) { + const google::protobuf::EnumValueDescriptor* value_desc = + descriptor->FindValueByNumber(value); + if (value_desc == nullptr) { + return "unknown"; + } + return value_desc->name(); +} + +Result FromProto(const ::substrait::AggregateFunction& func, bool is_hash, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - if (func.phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) { + if (func.phase() != + ::substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) { return Status::NotImplemented( "Unsupported aggregation phase '", - EnumToString(func.phase(), *substrait::AggregationPhase_descriptor()), + EnumToString(func.phase(), *::substrait::AggregationPhase_descriptor()), "'. Only INITIAL_TO_RESULT is supported"); } if (func.invocation() != - substrait::AggregateFunction::AggregationInvocation:: + ::substrait::AggregateFunction::AggregationInvocation:: AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL) { return Status::NotImplemented( "Unsupported aggregation invocation '", EnumToString(func.invocation(), - *substrait::AggregateFunction::AggregationInvocation_descriptor()), + *::substrait::AggregateFunction::AggregationInvocation_descriptor()), "'. Only AGGREGATION_INVOCATION_ALL is " "supported"); } @@ -166,17 +174,17 @@ Result FromProto(const substrait::AggregateFunction& func, bool i return std::move(call); } -Result FromProto(const substrait::Expression& expr, +Result FromProto(const ::substrait::Expression& expr, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { switch (expr.rex_type_case()) { - case substrait::Expression::kLiteral: { + case ::substrait::Expression::kLiteral: { ARROW_ASSIGN_OR_RAISE(auto datum, FromProto(expr.literal(), ext_set, conversion_options)); return compute::literal(std::move(datum)); } - case substrait::Expression::kSelection: { + case ::substrait::Expression::kSelection: { if (!expr.selection().has_direct_reference()) break; std::optional out; @@ -188,7 +196,7 @@ Result FromProto(const substrait::Expression& expr, const auto* ref = &expr.selection().direct_reference(); while (ref != nullptr) { switch (ref->reference_type_case()) { - case substrait::Expression::ReferenceSegment::kStructField: { + case ::substrait::Expression::ReferenceSegment::kStructField: { auto index = ref->struct_field().field(); if (!out) { // Root StructField (column selection) @@ -217,11 +225,11 @@ Result FromProto(const substrait::Expression& expr, } break; } - case substrait::Expression::ReferenceSegment::kListElement: { + case ::substrait::Expression::ReferenceSegment::kListElement: { if (!out) { // Root ListField (illegal) return Status::Invalid( - "substrait::ListElement cannot take a Relation as an argument"); + "::substrait::ListElement cannot take a Relation as an argument"); } // ListField on top of an arbitrary expression @@ -249,7 +257,7 @@ Result FromProto(const substrait::Expression& expr, break; } - case substrait::Expression::kIfThen: { + case ::substrait::Expression::kIfThen: { const auto& if_then = expr.if_then(); if (!if_then.has_else_()) break; if (if_then.ifs_size() == 0) break; @@ -289,7 +297,7 @@ Result FromProto(const substrait::Expression& expr, return compute::call("case_when", std::move(args)); } - case substrait::Expression::kScalarFunction: { + case ::substrait::Expression::kScalarFunction: { const auto& scalar_fn = expr.scalar_function(); ARROW_ASSIGN_OR_RAISE(Id function_id, @@ -322,7 +330,7 @@ Result FromProto(const substrait::Expression& expr, expr.DebugString()); } -Result FromProto(const substrait::Expression::Literal& lit, +Result FromProto(const ::substrait::Expression::Literal& lit, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (lit.nullable() && @@ -334,43 +342,43 @@ Result FromProto(const substrait::Expression::Literal& lit, } switch (lit.literal_type_case()) { - case substrait::Expression::Literal::kBoolean: + case ::substrait::Expression::Literal::kBoolean: return Datum(lit.boolean()); - case substrait::Expression::Literal::kI8: + case ::substrait::Expression::Literal::kI8: return Datum(static_cast(lit.i8())); - case substrait::Expression::Literal::kI16: + case ::substrait::Expression::Literal::kI16: return Datum(static_cast(lit.i16())); - case substrait::Expression::Literal::kI32: + case ::substrait::Expression::Literal::kI32: return Datum(static_cast(lit.i32())); - case substrait::Expression::Literal::kI64: + case ::substrait::Expression::Literal::kI64: return Datum(static_cast(lit.i64())); - case substrait::Expression::Literal::kFp32: + case ::substrait::Expression::Literal::kFp32: return Datum(lit.fp32()); - case substrait::Expression::Literal::kFp64: + case ::substrait::Expression::Literal::kFp64: return Datum(lit.fp64()); - case substrait::Expression::Literal::kString: + case ::substrait::Expression::Literal::kString: return Datum(lit.string()); - case substrait::Expression::Literal::kBinary: + case ::substrait::Expression::Literal::kBinary: return Datum(BinaryScalar(lit.binary())); - case substrait::Expression::Literal::kTimestamp: + case ::substrait::Expression::Literal::kTimestamp: return Datum( TimestampScalar(static_cast(lit.timestamp()), TimeUnit::MICRO)); - case substrait::Expression::Literal::kTimestampTz: + case ::substrait::Expression::Literal::kTimestampTz: return Datum(TimestampScalar(static_cast(lit.timestamp_tz()), TimeUnit::MICRO, TimestampTzTimezoneString())); - case substrait::Expression::Literal::kDate: + case ::substrait::Expression::Literal::kDate: return Datum(Date32Scalar(lit.date())); - case substrait::Expression::Literal::kTime: + case ::substrait::Expression::Literal::kTime: return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO)); - case substrait::Expression::Literal::kIntervalYearToMonth: - case substrait::Expression::Literal::kIntervalDayToSecond: { + case ::substrait::Expression::Literal::kIntervalYearToMonth: + case ::substrait::Expression::Literal::kIntervalDayToSecond: { Int32Builder builder; std::shared_ptr type; if (lit.has_interval_year_to_month()) { @@ -387,23 +395,23 @@ Result FromProto(const substrait::Expression::Literal& lit, ExtensionScalar(FixedSizeListScalar(std::move(array)), std::move(type))); } - case substrait::Expression::Literal::kUuid: + case ::substrait::Expression::Literal::kUuid: return Datum(ExtensionScalar(FixedSizeBinaryScalar(lit.uuid()), uuid())); - case substrait::Expression::Literal::kFixedChar: + case ::substrait::Expression::Literal::kFixedChar: return Datum( ExtensionScalar(FixedSizeBinaryScalar(lit.fixed_char()), fixed_char(static_cast(lit.fixed_char().size())))); - case substrait::Expression::Literal::kVarChar: + case ::substrait::Expression::Literal::kVarChar: return Datum( ExtensionScalar(StringScalar(lit.var_char().value()), varchar(static_cast(lit.var_char().length())))); - case substrait::Expression::Literal::kFixedBinary: + case ::substrait::Expression::Literal::kFixedBinary: return Datum(FixedSizeBinaryScalar(lit.fixed_binary())); - case substrait::Expression::Literal::kDecimal: { + case ::substrait::Expression::Literal::kDecimal: { if (lit.decimal().value().size() != sizeof(Decimal128)) { return Status::Invalid("Decimal literal had ", lit.decimal().value().size(), " bytes (expected ", sizeof(Decimal128), ")"); @@ -420,7 +428,7 @@ Result FromProto(const substrait::Expression::Literal& lit, return Datum(Decimal128Scalar(value, std::move(type))); } - case substrait::Expression::Literal::kStruct: { + case ::substrait::Expression::Literal::kStruct: { const auto& struct_ = lit.struct_(); ScalarVector fields(struct_.fields_size()); @@ -440,12 +448,12 @@ Result FromProto(const substrait::Expression::Literal& lit, return Datum(std::move(scalar)); } - case substrait::Expression::Literal::kList: { + case ::substrait::Expression::Literal::kList: { const auto& list = lit.list(); if (list.values_size() == 0) { return Status::Invalid( - "substrait::Expression::Literal::List had no values; should have been an " - "substrait::Expression::Literal::EmptyList"); + "::substrait::Expression::Literal::List had no values; should have been an " + "::substrait::Expression::Literal::EmptyList"); } std::shared_ptr element_type; @@ -473,12 +481,12 @@ Result FromProto(const substrait::Expression::Literal& lit, return Datum(ListScalar(std::move(arr))); } - case substrait::Expression::Literal::kMap: { + case ::substrait::Expression::Literal::kMap: { const auto& map = lit.map(); if (map.key_values_size() == 0) { return Status::Invalid( - "substrait::Expression::Literal::Map had no values; should have been an " - "substrait::Expression::Literal::EmptyMap"); + "::substrait::Expression::Literal::Map had no values; should have been an " + "::substrait::Expression::Literal::EmptyMap"); } std::shared_ptr key_type, value_type; @@ -534,14 +542,14 @@ Result FromProto(const substrait::Expression::Literal& lit, return Datum(std::make_shared(std::move(kv_arr))); } - case substrait::Expression::Literal::kEmptyList: { + case ::substrait::Expression::Literal::kEmptyList: { ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.empty_list().type(), ext_set, conversion_options)); ARROW_ASSIGN_OR_RAISE(auto values, MakeEmptyArray(type_nullable.first)); return ListScalar{std::move(values)}; } - case substrait::Expression::Literal::kEmptyMap: { + case ::substrait::Expression::Literal::kEmptyMap: { ARROW_ASSIGN_OR_RAISE( auto key_type_nullable, FromProto(lit.empty_map().key(), ext_set, conversion_options)); @@ -564,7 +572,7 @@ Result FromProto(const substrait::Expression::Literal& lit, return MapScalar{std::move(key_values)}; } - case substrait::Expression::Literal::kNull: { + case ::substrait::Expression::Literal::kNull: { ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.null(), ext_set, conversion_options)); if (!type_nullable.second) { @@ -587,10 +595,10 @@ namespace { struct ScalarToProtoImpl { Status Visit(const NullScalar& s) { return NotImplemented(s); } - using Lit = substrait::Expression::Literal; + using Lit = ::substrait::Expression::Literal; template - Status Primitive(void (substrait::Expression::Literal::*set)(Arg), + Status Primitive(void (::substrait::Expression::Literal::*set)(Arg), const PrimitiveScalar& primitive_scalar) { (lit_->*set)(static_cast(primitive_scalar.value)); return Status::OK(); @@ -811,27 +819,27 @@ struct ScalarToProtoImpl { Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); } Status NotImplemented(const Scalar& s) { - return Status::NotImplemented("conversion to substrait::Expression::Literal from ", + return Status::NotImplemented("conversion to ::substrait::Expression::Literal from ", s.ToString()); } Status operator()(const Scalar& scalar) { return VisitScalarInline(scalar, this); } - substrait::Expression::Literal* lit_; + ::substrait::Expression::Literal* lit_; ExtensionSet* ext_set_; const ConversionOptions& conversion_options_; }; } // namespace -Result> ToProto( +Result> ToProto( const Datum& datum, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (!datum.is_scalar()) { return Status::NotImplemented("representing ", datum.ToString(), - " as a substrait::Expression::Literal"); + " as a ::substrait::Expression::Literal"); } - auto out = std::make_unique(); + auto out = std::make_unique<::substrait::Expression::Literal>(); if (datum.scalar()->is_valid) { RETURN_NOT_OK( @@ -846,11 +854,11 @@ Result> ToProto( } static Status AddChildToReferenceSegment( - substrait::Expression::ReferenceSegment& segment, - std::unique_ptr&& child) { + ::substrait::Expression::ReferenceSegment& segment, + std::unique_ptr<::substrait::Expression::ReferenceSegment>&& child) { auto status = Status::Invalid("Attempt to add child to incomplete reference segment"); switch (segment.reference_type_case()) { - case substrait::Expression::ReferenceSegment::kMapKey: { + case ::substrait::Expression::ReferenceSegment::kMapKey: { auto map_key = segment.mutable_map_key(); if (map_key->has_child()) { status = AddChildToReferenceSegment(*map_key->mutable_child(), std::move(child)); @@ -860,7 +868,7 @@ static Status AddChildToReferenceSegment( } break; } - case substrait::Expression::ReferenceSegment::kStructField: { + case ::substrait::Expression::ReferenceSegment::kStructField: { auto struct_field = segment.mutable_struct_field(); if (struct_field->has_child()) { status = @@ -871,7 +879,7 @@ static Status AddChildToReferenceSegment( } break; } - case substrait::Expression::ReferenceSegment::kListElement: { + case ::substrait::Expression::ReferenceSegment::kListElement: { auto list_element = segment.mutable_list_element(); if (list_element->has_child()) { status = @@ -890,9 +898,9 @@ static Status AddChildToReferenceSegment( // Indexes the given Substrait expression or root (if expr is empty) using the given // ReferenceSegment. -static Result> MakeDirectReference( - std::unique_ptr&& expr, - std::unique_ptr&& ref_segment) { +static Result> MakeDirectReference( + std::unique_ptr<::substrait::Expression>&& expr, + std::unique_ptr<::substrait::Expression::ReferenceSegment>&& ref_segment) { // If expr is already a selection expression, add the index to its index stack. if (expr && expr->has_selection() && expr->selection().has_direct_reference()) { auto selection = expr->mutable_selection(); @@ -903,67 +911,67 @@ static Result> MakeDirectReference( } } - auto selection = std::make_unique(); + auto selection = std::make_unique<::substrait::Expression::FieldReference>(); selection->set_allocated_direct_reference(ref_segment.release()); - if (expr && expr->rex_type_case() != substrait::Expression::REX_TYPE_NOT_SET) { + if (expr && expr->rex_type_case() != ::substrait::Expression::REX_TYPE_NOT_SET) { selection->set_allocated_expression(expr.release()); } else { selection->set_allocated_root_reference( - new substrait::Expression::FieldReference::RootReference()); + new ::substrait::Expression::FieldReference::RootReference()); } - auto out = std::make_unique(); + auto out = std::make_unique<::substrait::Expression>(); out->set_allocated_selection(selection.release()); return std::move(out); } // Indexes the given Substrait struct-typed expression or root (if expr is empty) using // the given field index. -static Result> MakeStructFieldReference( - std::unique_ptr&& expr, int field) { +static Result> MakeStructFieldReference( + std::unique_ptr<::substrait::Expression>&& expr, int field) { auto struct_field = - std::make_unique(); + std::make_unique<::substrait::Expression::ReferenceSegment::StructField>(); struct_field->set_field(field); - auto ref_segment = std::make_unique(); + auto ref_segment = std::make_unique<::substrait::Expression::ReferenceSegment>(); ref_segment->set_allocated_struct_field(struct_field.release()); return MakeDirectReference(std::move(expr), std::move(ref_segment)); } // Indexes the given Substrait list-typed expression using the given offset. -static Result> MakeListElementReference( - std::unique_ptr&& expr, int offset) { +static Result> MakeListElementReference( + std::unique_ptr<::substrait::Expression>&& expr, int offset) { auto list_element = - std::make_unique(); + std::make_unique<::substrait::Expression::ReferenceSegment::ListElement>(); list_element->set_offset(offset); - auto ref_segment = std::make_unique(); + auto ref_segment = std::make_unique<::substrait::Expression::ReferenceSegment>(); ref_segment->set_allocated_list_element(list_element.release()); return MakeDirectReference(std::move(expr), std::move(ref_segment)); } -Result> EncodeSubstraitCall( +Result> EncodeSubstraitCall( const SubstraitCall& call, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id())); - auto scalar_fn = std::make_unique(); + auto scalar_fn = std::make_unique<::substrait::Expression::ScalarFunction>(); scalar_fn->set_function_reference(anchor); ARROW_ASSIGN_OR_RAISE( - std::unique_ptr output_type, + std::unique_ptr<::substrait::Type> output_type, ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options)); scalar_fn->set_allocated_output_type(output_type.release()); for (int i = 0; i < call.size(); i++) { - substrait::FunctionArgument* arg = scalar_fn->add_arguments(); + ::substrait::FunctionArgument* arg = scalar_fn->add_arguments(); if (call.HasEnumArg(i)) { ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(i)); arg->set_enum_(std::string(enum_val)); } else if (call.HasValueArg(i)) { ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr value_expr, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::Expression> value_expr, ToProto(value_arg, ext_set, conversion_options)); arg->set_allocated_value(value_expr.release()); } else { @@ -974,14 +982,14 @@ Result> EncodeSubstraitCa return std::move(scalar_fn); } -Result> ToProto( +Result> ToProto( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (!expr.IsBound()) { return Status::Invalid("ToProto requires a bound Expression"); } - auto out = std::make_unique(); + auto out = std::make_unique<::substrait::Expression>(); if (auto datum = expr.literal()) { ARROW_ASSIGN_OR_RAISE(auto literal, ToProto(*datum, ext_set, conversion_options)); @@ -1006,11 +1014,11 @@ Result> ToProto( auto conditions = call->arguments[0].call(); if (conditions && conditions->function_name == "make_struct") { // catch the special case of calls convertible to IfThen - auto if_then_ = std::make_unique(); + auto if_then_ = std::make_unique<::substrait::Expression::IfThen>(); // don't try to convert argument 0 of the case_when; we have to convert the elements // of make_struct individually - std::vector> arguments( + std::vector> arguments( call->arguments.size() - 1); for (size_t i = 1; i < call->arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i - 1], @@ -1020,7 +1028,7 @@ Result> ToProto( for (size_t i = 0; i < conditions->arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(auto cond_substrait, ToProto(conditions->arguments[i], ext_set, conversion_options)); - auto clause = std::make_unique(); + auto clause = std::make_unique<::substrait::Expression::IfThen::IfClause>(); clause->set_allocated_if_(cond_substrait.release()); clause->set_allocated_then(arguments[i].release()); if_then_->mutable_ifs()->AddAllocated(clause.release()); @@ -1035,7 +1043,7 @@ Result> ToProto( // the remaining function pattern matchers only convert the function itself, so we // should be able to convert all its arguments first here - std::vector> arguments(call->arguments.size()); + std::vector> arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i], ToProto(call->arguments[i], ext_set, conversion_options)); @@ -1061,7 +1069,7 @@ Result> ToProto( if (arguments[0]->has_selection() && arguments[0]->selection().has_direct_reference()) { if (arguments[1]->has_literal() && arguments[1]->literal().literal_type_case() == - substrait::Expression::Literal::kI32) { + ::substrait::Expression::Literal::kI32) { return MakeListElementReference(std::move(arguments[0]), arguments[1]->literal().i32()); } @@ -1070,11 +1078,11 @@ Result> ToProto( if (call->function_name == "if_else") { // catch the special case of calls convertible to IfThen - auto if_clause = std::make_unique(); + auto if_clause = std::make_unique<::substrait::Expression::IfThen::IfClause>(); if_clause->set_allocated_if_(arguments[0].release()); if_clause->set_allocated_then(arguments[1].release()); - auto if_then = std::make_unique(); + auto if_then = std::make_unique<::substrait::Expression::IfThen>(); if_then->mutable_ifs()->AddAllocated(if_clause.release()); if_then->set_allocated_else_(arguments[2].release()); @@ -1087,8 +1095,9 @@ Result> ToProto( ExtensionIdRegistry::ArrowToSubstraitCall converter, ext_set->registry()->GetArrowToSubstraitCall(call->function_name)); ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr scalar_fn, - EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr<::substrait::Expression::ScalarFunction> scalar_fn, + EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); out->set_allocated_scalar_function(scalar_fn.release()); return std::move(out); } diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index e947537dd1e..65b3f41ddfc 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -32,28 +32,25 @@ namespace arrow { namespace engine { -class SubstraitCall; - ARROW_ENGINE_EXPORT -Result FromProto(const substrait::Expression&, const ExtensionSet&, +Result FromProto(const ::substrait::Expression&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const compute::Expression&, - ExtensionSet*, - const ConversionOptions&); +Result> ToProto(const compute::Expression&, + ExtensionSet*, + const ConversionOptions&); ARROW_ENGINE_EXPORT -Result FromProto(const substrait::Expression::Literal&, const ExtensionSet&, +Result FromProto(const ::substrait::Expression::Literal&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const Datum&, - ExtensionSet*, - const ConversionOptions&); +Result> ToProto( + const Datum&, ExtensionSet*, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result FromProto(const substrait::AggregateFunction&, bool is_hash, +Result FromProto(const ::substrait::AggregateFunction&, bool is_hash, const ExtensionSet&, const ConversionOptions&); } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index 0fe7527ee83..9dfd4d7856a 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -29,15 +29,13 @@ namespace arrow { namespace engine { -namespace substrait = ::substrait; - class DefaultExtensionProvider : public ExtensionProvider { public: Result MakeRel(const std::vector& inputs, const google::protobuf::Any& rel, const ExtensionSet& ext_set) override { - if (rel.Is()) { - arrow::substrait::AsOfJoinRel as_of_join_rel; + if (rel.Is()) { + arrow::substrait_ext::AsOfJoinRel as_of_join_rel; rel.UnpackTo(&as_of_join_rel); return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set); } @@ -48,13 +46,14 @@ class DefaultExtensionProvider : public ExtensionProvider { private: Result MakeAsOfJoinRel( const std::vector& inputs, - const arrow::substrait::AsOfJoinRel& as_of_join_rel, const ExtensionSet& ext_set) { + const arrow::substrait_ext::AsOfJoinRel& as_of_join_rel, + const ExtensionSet& ext_set) { if (inputs.size() < 2) { - return Status::Invalid("substrait::AsOfJoinNode too few input tables: ", + return Status::Invalid("substrait_ext::AsOfJoinNode too few input tables: ", inputs.size()); } if (static_cast(as_of_join_rel.keys_size()) != inputs.size()) { - return Status::Invalid("substrait::AsOfJoinNode mismatched number of inputs"); + return Status::Invalid("substrait_ext::AsOfJoinNode mismatched number of inputs"); } size_t n_input = inputs.size(), i = 0; @@ -62,12 +61,13 @@ class DefaultExtensionProvider : public ExtensionProvider { for (const auto& keys : as_of_join_rel.keys()) { // on-key if (!keys.has_on()) { - return Status::Invalid("substrait::AsOfJoinNode missing on-key for input ", i); + return Status::Invalid("substrait_ext::AsOfJoinNode missing on-key for input ", + i); } ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(keys.on(), ext_set, {})); if (on_key_expr.field_ref() == NULLPTR) { return Status::NotImplemented( - "substrait::AsOfJoinNode non-field-ref on-key for input ", i); + "substrait_ext::AsOfJoinNode non-field-ref on-key for input ", i); } const FieldRef& on_key = *on_key_expr.field_ref(); @@ -77,7 +77,7 @@ class DefaultExtensionProvider : public ExtensionProvider { ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {})); if (by_key_expr.field_ref() == NULLPTR) { return Status::NotImplemented( - "substrait::AsOfJoinNode non-field-ref by-key for input ", i); + "substrait_ext::AsOfJoinNode non-field-ref by-key for input ", i); } by_key.push_back(*by_key_expr.field_ref()); } diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index f787a0206e9..dfec4fa7336 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -42,9 +42,7 @@ using internal::checked_cast; namespace engine { -namespace substrait = ::substrait; - -Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) { +Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* plan) { plan->clear_extension_uris(); std::unordered_map map; @@ -55,7 +53,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) auto uri = ext_set.uris().at(anchor); if (uri.empty()) continue; - auto ext_uri = std::make_unique(); + auto ext_uri = std::make_unique<::substrait::extensions::SimpleExtensionURI>(); ext_uri->set_uri(std::string(uri)); ext_uri->set_extension_uri_anchor(anchor); uris->AddAllocated(ext_uri.release()); @@ -66,7 +64,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) auto extensions = plan->mutable_extensions(); extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); - using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; + using ExtDecl = ::substrait::extensions::SimpleExtensionDeclaration; for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); @@ -98,7 +96,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) return Status::OK(); } -Result GetExtensionSetFromPlan(const substrait::Plan& plan, +Result GetExtensionSetFromPlan(const ::substrait::Plan& plan, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { if (registry == NULLPTR) { @@ -116,18 +114,18 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, std::unordered_map type_ids, function_ids; for (const auto& ext : plan.extensions()) { switch (ext.mapping_type_case()) { - case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { + case ::substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { return Status::NotImplemented("Type Variations are not yet implemented"); } - case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { + case ::substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { const auto& type = ext.extension_type(); std::string_view uri = uris[type.extension_uri_reference()]; type_ids[type.type_anchor()] = Id{uri, type.name()}; break; } - case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + case ::substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { const auto& fn = ext.extension_function(); std::string_view uri = uris[fn.extension_uri_reference()]; function_ids[fn.function_anchor()] = Id{uri, fn.name()}; @@ -150,8 +148,8 @@ constexpr uint32_t kSubstraitMajorVersion = 0; constexpr uint32_t kSubstraitMinorVersion = 20; constexpr uint32_t kSubstraitPatchVersion = 0; -std::unique_ptr CreateVersion() { - auto version = std::make_unique(); +std::unique_ptr<::substrait::Version> CreateVersion() { + auto version = std::make_unique<::substrait::Version>(); version->set_major_number(kSubstraitMajorVersion); version->set_minor_number(kSubstraitMinorVersion); version->set_patch_number(kSubstraitPatchVersion); @@ -161,13 +159,13 @@ std::unique_ptr CreateVersion() { } // namespace -Result> PlanToProto( +Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto subs_plan = std::make_unique(); + auto subs_plan = std::make_unique<::substrait::Plan>(); subs_plan->set_allocated_version(CreateVersion().release()); - auto plan_rel = std::make_unique(); - auto rel_root = std::make_unique(); + auto plan_rel = std::make_unique<::substrait::PlanRel>(); + auto rel_root = std::make_unique<::substrait::RelRoot>(); ARROW_ASSIGN_OR_RAISE(auto rel, ToProto(declr, ext_set, conversion_options)); rel_root->set_allocated_input(rel.release()); plan_rel->set_allocated_root(rel_root.release()); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index c2094ae1e61..0299352bfc1 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -33,8 +33,6 @@ namespace arrow { namespace engine { -namespace substrait = ::substrait; - /// \brief Replaces the extension information of a Substrait Plan message with the given /// extension set, such that the anchors defined therein can be used in the rest of the /// plan. @@ -43,7 +41,7 @@ namespace substrait = ::substrait; /// \param[in,out] plan the Substrait plan message that is to be updated /// \return success or failure ARROW_ENGINE_EXPORT -Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan); +Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* plan); /// \brief Interprets the extension information of a Substrait Plan message into an /// ExtensionSet. @@ -55,10 +53,10 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) /// correspond to Substrait's URI/name pairs ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( - const substrait::Plan& plan, const ConversionOptions& conversion_options, + const ::substrait::Plan& plan, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry = default_extension_id_registry()); -/// \brief Serialize a declaration into a substrait::Plan. +/// \brief Serialize a declaration into a ::substrait::Plan. /// /// Note that, this is a part of a roundtripping test API and not /// designed for use in production @@ -66,7 +64,7 @@ Result GetExtensionSetFromPlan( /// \param[in, out] ext_set the extension set to be updated /// \param[in] conversion_options options to control serialization behavior /// \return the serialized plan -ARROW_ENGINE_EXPORT Result> PlanToProto( +ARROW_ENGINE_EXPORT Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 5e48000643e..b9b5016d24c 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -69,8 +69,6 @@ using internal::UriFromAbsolutePath; namespace engine { -namespace substrait = ::substrait; - struct EmitInfo { std::vector expressions; std::shared_ptr schema; @@ -100,9 +98,9 @@ Result ProcessEmit(const RelMessage& rel, const std::shared_ptr& schema) { if (rel.has_common()) { switch (rel.common().emit_kind_case()) { - case substrait::RelCommon::EmitKindCase::kDirect: + case ::substrait::RelCommon::EmitKindCase::kDirect: return no_emit_declr; - case substrait::RelCommon::EmitKindCase::kEmit: { + case ::substrait::RelCommon::EmitKindCase::kEmit: { ARROW_ASSIGN_OR_RAISE(auto emit_info, GetEmitInfo(rel, schema)); return DeclarationInfo{ compute::Declaration::Sequence( @@ -125,10 +123,10 @@ Status CheckRelCommon(const RelMessage& rel, if (rel.has_common()) { if (rel.common().has_hint() && conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP) { - return Status::NotImplemented("substrait::RelCommon::Hint"); + return Status::NotImplemented("::substrait::RelCommon::Hint"); } if (rel.common().has_advanced_extension()) { - return Status::NotImplemented("substrait::RelCommon::advanced_extension"); + return Status::NotImplemented("::substrait::RelCommon::advanced_extension"); } } if (rel.has_advanced_extension()) { @@ -155,7 +153,8 @@ Status DiscoverFilesFromDir(const std::shared_ptr& local_fs return Status::OK(); } -Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set, +Result FromProto(const ::substrait::Rel& rel, + const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { static bool dataset_init = false; if (!dataset_init) { @@ -164,7 +163,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } switch (rel.rel_type_case()) { - case substrait::Rel::RelTypeCase::kRead: { + case ::substrait::Rel::RelTypeCase::kRead: { const auto& read = rel.read(); RETURN_NOT_OK(CheckRelCommon(read, conversion_options)); @@ -181,7 +180,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } if (read.has_projection()) { - return Status::NotImplemented("substrait::ReadRel::projection"); + return Status::NotImplemented("::substrait::ReadRel::projection"); } if (read.has_named_table()) { @@ -197,7 +196,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& const NamedTableProvider& named_table_provider = conversion_options.named_table_provider; - const substrait::ReadRel::NamedTable& named_table = read.named_table(); + const ::substrait::ReadRel::NamedTable& named_table = read.named_table(); std::vector table_names(named_table.names().begin(), named_table.names().end()); ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, @@ -214,12 +213,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& if (!read.has_local_files()) { return Status::NotImplemented( - "substrait::ReadRel with read_type other than LocalFiles"); + "::substrait::ReadRel with read_type other than LocalFiles"); } if (read.local_files().has_advanced_extension()) { return Status::NotImplemented( - "substrait::ReadRel::LocalFiles::advanced_extension"); + "::substrait::ReadRel::LocalFiles::advanced_extension"); } std::shared_ptr format; @@ -230,31 +229,32 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& // Validate properties of the `FileOrFiles` item if (item.partition_index() != 0) { return Status::NotImplemented( - "non-default substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); + "non-default " + "::substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); } if (item.start() != 0) { return Status::NotImplemented( - "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start offset"); + "non-default ::substrait::ReadRel::LocalFiles::FileOrFiles::start offset"); } if (item.length() != 0) { return Status::NotImplemented( - "non-default substrait::ReadRel::LocalFiles::FileOrFiles::length"); + "non-default ::substrait::ReadRel::LocalFiles::FileOrFiles::length"); } // Extract and parse the read relation's source URI ::arrow::internal::Uri item_uri; switch (item.path_type_case()) { - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: RETURN_NOT_OK(item_uri.Parse(item.uri_path())); break; - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: RETURN_NOT_OK(item_uri.Parse(item.uri_file())); break; - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: RETURN_NOT_OK(item_uri.Parse(item.uri_folder())); break; @@ -265,49 +265,49 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& // Validate the URI before processing if (!item_uri.is_file_scheme()) { - return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", + return Status::NotImplemented("::substrait::ReadRel::LocalFiles item (", item_uri.ToString(), ") does not have file scheme (file:///)"); } if (item_uri.port() != -1) { - return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", + return Status::NotImplemented("::substrait::ReadRel::LocalFiles item (", item_uri.ToString(), ") should not have a port number in path"); } if (!item_uri.query_string().empty()) { - return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", + return Status::NotImplemented("::substrait::ReadRel::LocalFiles item (", item_uri.ToString(), ") should not have a query string in path"); } switch (item.file_format_case()) { - case substrait::ReadRel::LocalFiles::FileOrFiles::kParquet: + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kParquet: format = std::make_shared(); break; - case substrait::ReadRel::LocalFiles::FileOrFiles::kArrow: + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kArrow: format = std::make_shared(); break; default: return Status::NotImplemented( "unsupported file format ", - "(see substrait::ReadRel::LocalFiles::FileOrFiles::file_format)"); + "(see ::substrait::ReadRel::LocalFiles::FileOrFiles::file_format)"); } // Handle the URI as appropriate switch (item.path_type_case()) { - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: { + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: { files.emplace_back(item_uri.path(), fs::FileType::File); break; } - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: { + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: { RETURN_NOT_OK(DiscoverFilesFromDir(filesystem, item_uri.path(), &files)); break; } - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: { + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: { ARROW_ASSIGN_OR_RAISE(auto file_info, filesystem->GetFileInfo(item_uri.path())); @@ -330,7 +330,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& break; } - case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: { + case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: { ARROW_ASSIGN_OR_RAISE(auto globbed_files, fs::internal::GlobFiles(filesystem, item_uri.path())); std::move(globbed_files.begin(), globbed_files.end(), @@ -358,18 +358,18 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(base_schema)); } - case substrait::Rel::RelTypeCase::kFilter: { + case ::substrait::Rel::RelTypeCase::kFilter: { const auto& filter = rel.filter(); RETURN_NOT_OK(CheckRelCommon(filter, conversion_options)); if (!filter.has_input()) { - return Status::Invalid("substrait::FilterRel with no input relation"); + return Status::Invalid("::substrait::FilterRel with no input relation"); } ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set, conversion_options)); if (!filter.has_condition()) { - return Status::Invalid("substrait::FilterRel with no condition expression"); + return Status::Invalid("::substrait::FilterRel with no condition expression"); } ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); @@ -384,11 +384,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& input.output_schema); } - case substrait::Rel::RelTypeCase::kProject: { + case ::substrait::Rel::RelTypeCase::kProject: { const auto& project = rel.project(); RETURN_NOT_OK(CheckRelCommon(project, conversion_options)); if (!project.has_input()) { - return Status::Invalid("substrait::ProjectRel with no input relation"); + return Status::Invalid("::substrait::ProjectRel with no input relation"); } ARROW_ASSIGN_OR_RAISE(auto input, FromProto(project.input(), ext_set, conversion_options)); @@ -437,38 +437,38 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(project_schema)); } - case substrait::Rel::RelTypeCase::kJoin: { + case ::substrait::Rel::RelTypeCase::kJoin: { const auto& join = rel.join(); RETURN_NOT_OK(CheckRelCommon(join, conversion_options)); if (!join.has_left()) { - return Status::Invalid("substrait::JoinRel with no left relation"); + return Status::Invalid("::substrait::JoinRel with no left relation"); } if (!join.has_right()) { - return Status::Invalid("substrait::JoinRel with no right relation"); + return Status::Invalid("::substrait::JoinRel with no right relation"); } compute::JoinType join_type; switch (join.type()) { - case substrait::JoinRel::JOIN_TYPE_UNSPECIFIED: + case ::substrait::JoinRel::JOIN_TYPE_UNSPECIFIED: return Status::NotImplemented("Unspecified join type is not supported"); - case substrait::JoinRel::JOIN_TYPE_INNER: + case ::substrait::JoinRel::JOIN_TYPE_INNER: join_type = compute::JoinType::INNER; break; - case substrait::JoinRel::JOIN_TYPE_OUTER: + case ::substrait::JoinRel::JOIN_TYPE_OUTER: join_type = compute::JoinType::FULL_OUTER; break; - case substrait::JoinRel::JOIN_TYPE_LEFT: + case ::substrait::JoinRel::JOIN_TYPE_LEFT: join_type = compute::JoinType::LEFT_OUTER; break; - case substrait::JoinRel::JOIN_TYPE_RIGHT: + case ::substrait::JoinRel::JOIN_TYPE_RIGHT: join_type = compute::JoinType::RIGHT_OUTER; break; - case substrait::JoinRel::JOIN_TYPE_SEMI: + case ::substrait::JoinRel::JOIN_TYPE_SEMI: join_type = compute::JoinType::LEFT_SEMI; break; - case substrait::JoinRel::JOIN_TYPE_ANTI: + case ::substrait::JoinRel::JOIN_TYPE_ANTI: join_type = compute::JoinType::LEFT_ANTI; break; default: @@ -481,7 +481,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& FromProto(join.right(), ext_set, conversion_options)); if (!join.has_expression()) { - return Status::Invalid("substrait::JoinRel with no expression"); + return Status::Invalid("::substrait::JoinRel with no expression"); } ARROW_ASSIGN_OR_RAISE(auto expression, @@ -543,12 +543,12 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return ProcessEmit(std::move(join), std::move(join_declaration), std::move(join_schema)); } - case substrait::Rel::RelTypeCase::kAggregate: { + case ::substrait::Rel::RelTypeCase::kAggregate: { const auto& aggregate = rel.aggregate(); RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options)); if (!aggregate.has_input()) { - return Status::Invalid("substrait::AggregateRel with no input relation"); + return Status::Invalid("::substrait::AggregateRel with no input relation"); } ARROW_ASSIGN_OR_RAISE(auto input, @@ -566,7 +566,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::vector key_field_ids; std::vector keys; if (aggregate.groupings_size() > 0) { - const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); + const ::substrait::AggregateRel::Grouping& group = aggregate.groupings(0); int grouping_expr_size = group.grouping_expressions_size(); keys.reserve(grouping_expr_size); key_field_ids.reserve(grouping_expr_size); @@ -620,7 +620,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& aggregates.push_back(std::move(arrow_agg)); } else { - return Status::Invalid("substrait::AggregateFunction not provided"); + return Status::Invalid("::substrait::AggregateFunction not provided"); } } FieldVector output_fields; @@ -688,14 +688,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return ProcessEmit(std::move(set), std::move(set_declaration), std::move(union_schema)); } - case substrait::Rel::RelTypeCase::kExtensionLeaf: { + case ::substrait::Rel::RelTypeCase::kExtensionLeaf: { const auto& ext = rel.extension_leaf(); ARROW_ASSIGN_OR_RAISE( auto ext_leaf_decl, conversion_options.extension_provider->MakeRel({}, ext.detail(), ext_set)); return ProcessEmit(ext, std::move(ext_leaf_decl), ext_leaf_decl.output_schema); } - case substrait::Rel::RelTypeCase::kExtensionSingle: { + case ::substrait::Rel::RelTypeCase::kExtensionSingle: { const auto& ext = rel.extension_single(); ARROW_ASSIGN_OR_RAISE(DeclarationInfo input, FromProto(ext.input(), ext_set, conversion_options)); @@ -704,7 +704,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& conversion_options.extension_provider->MakeRel({input}, ext.detail(), ext_set)); return ProcessEmit(ext, std::move(ext_single_decl), ext_single_decl.output_schema); } - case substrait::Rel::RelTypeCase::kExtensionMulti: { + case ::substrait::Rel::RelTypeCase::kExtensionMulti: { const auto& ext = rel.extension_multi(); std::vector inputs; for (const auto& input : ext.inputs()) { @@ -778,7 +778,7 @@ Result> NamedTableRelationConverter( Result> ScanRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto read_rel = std::make_unique(); + auto read_rel = std::make_unique<::substrait::ReadRel>(); const auto& scan_node_options = checked_cast(*declaration.options); auto dataset = @@ -793,22 +793,22 @@ Result> ScanRelationConverter( read_rel->set_allocated_base_schema(named_struct.release()); // set local files - auto read_rel_lfs = std::make_unique(); + auto read_rel_lfs = std::make_unique<::substrait::ReadRel::LocalFiles>(); for (const auto& file : dataset->files()) { auto read_rel_lfs_ffs = - std::make_unique(); + std::make_unique<::substrait::ReadRel::LocalFiles::FileOrFiles>(); read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); // set file format auto format_type_name = dataset->format()->type_name(); if (format_type_name == "parquet") { read_rel_lfs_ffs->set_allocated_parquet( - new substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); + new ::substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); } else if (format_type_name == "ipc") { read_rel_lfs_ffs->set_allocated_arrow( - new substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); + new ::substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); } else if (format_type_name == "orc") { read_rel_lfs_ffs->set_allocated_orc( - new substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); + new ::substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); } else { return Status::NotImplemented("Unsupported file type: ", format_type_name); } @@ -818,10 +818,10 @@ Result> ScanRelationConverter( return std::move(read_rel); } -Result> FilterRelationConverter( +Result> FilterRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto filter_rel = std::make_unique(); + auto filter_rel = std::make_unique<::substrait::FilterRel>(); const auto& filter_node_options = checked_cast(*(declaration.options)); @@ -852,7 +852,7 @@ Result> FilterRelationConverter( Status SerializeAndCombineRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, - std::unique_ptr* rel, + std::unique_ptr<::substrait::Rel>* rel, const ConversionOptions& conversion_options) { const auto& factory_name = declaration.factory_name; ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); @@ -889,10 +889,10 @@ Status SerializeAndCombineRelations(const compute::Declaration& declaration, return Status::OK(); } -Result> ToProto( +Result> ToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto rel = std::make_unique(); + auto rel = std::make_unique<::substrait::Rel>(); RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); return std::move(rel); } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index ee2848122d3..ab63f5ed7f6 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -42,7 +42,7 @@ struct ARROW_ENGINE_EXPORT DeclarationInfo { /// \brief Convert a Substrait Rel object to an Acero declaration ARROW_ENGINE_EXPORT -Result FromProto(const substrait::Rel&, const ExtensionSet&, +Result FromProto(const ::substrait::Rel&, const ExtensionSet&, const ConversionOptions&); /// \brief Convert an Acero Declaration to a Substrait Rel @@ -51,7 +51,7 @@ Result FromProto(const substrait::Rel&, const ExtensionSet&, /// the ExecNode or ExecPlan are not used in this context as Declaration /// is preferred in the Substrait space rather than internal components of /// Acero execution engine. -ARROW_ENGINE_EXPORT Result> ToProto( +ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 8290f14caf3..7c55e9b1ec7 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -47,8 +47,6 @@ namespace arrow { namespace engine { -namespace substrait = ::substrait; - Status ParseFromBufferImpl(const Buffer& buf, const std::string& full_name, google::protobuf::Message* message) { google::protobuf::io::ArrayInputStream buf_stream{buf.data(), @@ -88,7 +86,7 @@ Result> SerializeRelation( Result DeserializeRelation( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); + ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer<::substrait::Rel>(buf)); ARROW_ASSIGN_OR_RAISE(auto decl_info, FromProto(rel, ext_set, conversion_options)); return std::move(decl_info.declaration); } @@ -137,7 +135,7 @@ Result> DeserializePlans( const Buffer& buf, DeclarationFactory declaration_factory, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); + ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<::substrait::Plan>(buf)); if (plan.version().major_number() < kMinimumMajorVersion && plan.version().minor_number() < kMinimumMinorVersion) { @@ -149,7 +147,7 @@ Result> DeserializePlans( GetExtensionSetFromPlan(plan, conversion_options, registry)); std::vector sink_decls; - for (const substrait::PlanRel& plan_rel : plan.relations()) { + for (const ::substrait::PlanRel& plan_rel : plan.relations()) { ARROW_ASSIGN_OR_RAISE( auto decl_info, FromProto(plan_rel.has_root() ? plan_rel.root().input() : plan_rel.rel(), ext_set, @@ -248,7 +246,8 @@ Result> DeserializePlan( Result> DeserializeSchema( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer(buf)); + ARROW_ASSIGN_OR_RAISE(auto named_struct, + ParseFromBuffer<::substrait::NamedStruct>(buf)); return FromProto(named_struct, ext_set, conversion_options); } @@ -263,7 +262,7 @@ Result> SerializeSchema( Result> DeserializeType( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto type, ParseFromBuffer(buf)); + ARROW_ASSIGN_OR_RAISE(auto type, ParseFromBuffer<::substrait::Type>(buf)); ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(type, ext_set, conversion_options)); return std::move(type_nullable.first); } @@ -280,7 +279,7 @@ Result> SerializeType( Result DeserializeExpression( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto expr, ParseFromBuffer(buf)); + ARROW_ASSIGN_OR_RAISE(auto expr, ParseFromBuffer<::substrait::Expression>(buf)); return FromProto(expr, ext_set, conversion_options); } @@ -318,11 +317,11 @@ static Status CheckMessagesEquivalent(const Buffer& l_buf, const Buffer& r_buf) Status CheckMessagesEquivalent(std::string_view message_name, const Buffer& l_buf, const Buffer& r_buf) { if (message_name == "Type") { - return CheckMessagesEquivalent(l_buf, r_buf); + return CheckMessagesEquivalent<::substrait::Type>(l_buf, r_buf); } if (message_name == "NamedStruct") { - return CheckMessagesEquivalent(l_buf, r_buf); + return CheckMessagesEquivalent<::substrait::NamedStruct>(l_buf, r_buf); } if (message_name == "Schema") { @@ -332,11 +331,11 @@ Status CheckMessagesEquivalent(std::string_view message_name, const Buffer& l_bu } if (message_name == "Expression") { - return CheckMessagesEquivalent(l_buf, r_buf); + return CheckMessagesEquivalent<::substrait::Expression>(l_buf, r_buf); } if (message_name == "Rel") { - return CheckMessagesEquivalent(l_buf, r_buf); + return CheckMessagesEquivalent<::substrait::Rel>(l_buf, r_buf); } if (message_name == "Relation") { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index d302d753796..80d08926799 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -4104,7 +4104,7 @@ TEST(Substrait, PlanWithAsOfJoinExtension) { } ], "detail": { - "@type": "/arrow.substrait.AsOfJoinRel", + "@type": "/arrow.substrait_ext.AsOfJoinRel", "keys" : [ { "on": { diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc index d2dbe77e8d6..0e28e49b7af 100644 --- a/cpp/src/arrow/engine/substrait/test_plan_builder.cc +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc @@ -37,59 +37,58 @@ namespace arrow { namespace engine { -namespace substrait = ::substrait; namespace internal { static const ConversionOptions kPlanBuilderConversionOptions; -Result> CreateRead(const Table& table, - ExtensionSet* ext_set) { - auto read = std::make_unique(); +Result> CreateRead(const Table& table, + ExtensionSet* ext_set) { + auto read = std::make_unique<::substrait::ReadRel>(); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::NamedStruct> schema, ToProto(*table.schema(), ext_set, kPlanBuilderConversionOptions)); read->set_allocated_base_schema(schema.release()); - auto named_table = std::make_unique(); + auto named_table = std::make_unique<::substrait::ReadRel::NamedTable>(); named_table->add_names("test"); read->set_allocated_named_table(named_table.release()); return read; } -void CreateDirectReference(int32_t index, substrait::Expression* expr) { - auto reference = std::make_unique(); - auto reference_segment = std::make_unique(); +void CreateDirectReference(int32_t index, ::substrait::Expression* expr) { + auto reference = std::make_unique<::substrait::Expression::FieldReference>(); + auto reference_segment = std::make_unique<::substrait::Expression::ReferenceSegment>(); auto struct_field = - std::make_unique(); + std::make_unique<::substrait::Expression::ReferenceSegment::StructField>(); struct_field->set_field(index); reference_segment->set_allocated_struct_field(struct_field.release()); reference->set_allocated_direct_reference(reference_segment.release()); auto root_reference = - std::make_unique(); + std::make_unique<::substrait::Expression::FieldReference::RootReference>(); reference->set_allocated_root_reference(root_reference.release()); expr->set_allocated_selection(reference.release()); } -Result> CreateProject( +Result> CreateProject( Id function_id, const std::vector& arguments, const std::unordered_map> options, const std::vector>& arg_types, const DataType& output_type, ExtensionSet* ext_set) { - auto project = std::make_unique(); + auto project = std::make_unique<::substrait::ProjectRel>(); - auto call = std::make_unique(); + auto call = std::make_unique<::substrait::Expression::ScalarFunction>(); ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id)); call->set_function_reference(function_anchor); std::size_t arg_index = 0; std::size_t table_arg_index = 0; for (const std::shared_ptr& arg_type : arg_types) { - substrait::FunctionArgument* argument = call->add_arguments(); + ::substrait::FunctionArgument* argument = call->add_arguments(); if (arg_type) { // If it has a type then it's a reference to the input table - auto expression = std::make_unique(); + auto expression = std::make_unique<::substrait::Expression>(); CreateDirectReference(static_cast(table_arg_index++), expression.get()); argument->set_allocated_value(expression.release()); } else { @@ -100,7 +99,7 @@ Result> CreateProject( arg_index++; } for (const auto& opt : options) { - substrait::FunctionOption* option = call->add_options(); + ::substrait::FunctionOption* option = call->add_options(); option->set_name(opt.first); for (const std::string& pref : opt.second) { option->add_preference(pref); @@ -108,49 +107,49 @@ Result> CreateProject( } ARROW_ASSIGN_OR_RAISE( - std::unique_ptr output_type_substrait, + std::unique_ptr<::substrait::Type> output_type_substrait, ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions)); call->set_allocated_output_type(output_type_substrait.release()); - substrait::Expression* call_expression = project->add_expressions(); + ::substrait::Expression* call_expression = project->add_expressions(); call_expression->set_allocated_scalar_function(call.release()); return project; } -Result> CreateAgg(Id function_id, - const std::vector& keys, - int arg_idx, - const DataType& output_type, - ExtensionSet* ext_set) { - auto agg = std::make_unique(); +Result> CreateAgg(Id function_id, + const std::vector& keys, + int arg_idx, + const DataType& output_type, + ExtensionSet* ext_set) { + auto agg = std::make_unique<::substrait::AggregateRel>(); if (!keys.empty()) { - substrait::AggregateRel::Grouping* grouping = agg->add_groupings(); + ::substrait::AggregateRel::Grouping* grouping = agg->add_groupings(); for (int key : keys) { - substrait::Expression* key_expr = grouping->add_grouping_expressions(); + ::substrait::Expression* key_expr = grouping->add_grouping_expressions(); CreateDirectReference(key, key_expr); } } - substrait::AggregateRel::Measure* measure_wrapper = agg->add_measures(); - auto agg_func = std::make_unique(); + ::substrait::AggregateRel::Measure* measure_wrapper = agg->add_measures(); + auto agg_func = std::make_unique<::substrait::AggregateFunction>(); ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id)); agg_func->set_function_reference(function_anchor); - substrait::FunctionArgument* arg = agg_func->add_arguments(); - auto arg_expr = std::make_unique(); + ::substrait::FunctionArgument* arg = agg_func->add_arguments(); + auto arg_expr = std::make_unique<::substrait::Expression>(); CreateDirectReference(arg_idx, arg_expr.get()); arg->set_allocated_value(arg_expr.release()); - agg_func->set_phase(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT); + agg_func->set_phase(::substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT); agg_func->set_invocation( - substrait::AggregateFunction::AggregationInvocation:: + ::substrait::AggregateFunction::AggregationInvocation:: AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL); ARROW_ASSIGN_OR_RAISE( - std::unique_ptr output_type_substrait, + std::unique_ptr<::substrait::Type> output_type_substrait, ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions)); agg_func->set_allocated_output_type(output_type_substrait.release()); measure_wrapper->set_allocated_measure(agg_func.release()); @@ -158,8 +157,8 @@ Result> CreateAgg(Id function_id, return agg; } -std::unique_ptr CreateTestVersion() { - auto version = std::make_unique(); +std::unique_ptr<::substrait::Version> CreateTestVersion() { + auto version = std::make_unique<::substrait::Version>(); version->set_major_number(std::numeric_limits::max()); version->set_minor_number(std::numeric_limits::max()); version->set_patch_number(std::numeric_limits::max()); @@ -167,13 +166,13 @@ std::unique_ptr CreateTestVersion() { return version; } -Result> CreatePlan(std::unique_ptr root, - ExtensionSet* ext_set) { - auto plan = std::make_unique(); +Result> CreatePlan( + std::unique_ptr<::substrait::Rel> root, ExtensionSet* ext_set) { + auto plan = std::make_unique<::substrait::Plan>(); plan->set_allocated_version(CreateTestVersion().release()); - substrait::PlanRel* plan_rel = plan->add_relations(); - auto rel_root = std::make_unique(); + ::substrait::PlanRel* plan_rel = plan->add_relations(); + auto rel_root = std::make_unique<::substrait::RelRoot>(); rel_root->set_allocated_input(root.release()); plan_rel->set_allocated_root(rel_root.release()); @@ -188,20 +187,20 @@ Result> CreateScanProjectSubstrait( const std::vector>& data_types, const DataType& output_type) { ExtensionSet ext_set; - ARROW_ASSIGN_OR_RAISE(std::unique_ptr read, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::ReadRel> read, CreateRead(*input_table, &ext_set)); ARROW_ASSIGN_OR_RAISE( - std::unique_ptr project, + std::unique_ptr<::substrait::ProjectRel> project, CreateProject(function_id, arguments, options, data_types, output_type, &ext_set)); - auto read_rel = std::make_unique(); + auto read_rel = std::make_unique<::substrait::Rel>(); read_rel->set_allocated_read(read.release()); project->set_allocated_input(read_rel.release()); - auto project_rel = std::make_unique(); + auto project_rel = std::make_unique<::substrait::Rel>(); project_rel->set_allocated_project(project.release()); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr plan, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::Plan> plan, CreatePlan(std::move(project_rel), &ext_set)); return Buffer::FromString(plan->SerializeAsString()); } @@ -211,19 +210,19 @@ Result> CreateScanAggSubstrait( const std::vector& key_idxs, int arg_idx, const DataType& output_type) { ExtensionSet ext_set; - ARROW_ASSIGN_OR_RAISE(std::unique_ptr read, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::ReadRel> read, CreateRead(*input_table, &ext_set)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr agg, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::AggregateRel> agg, CreateAgg(function_id, key_idxs, arg_idx, output_type, &ext_set)); - auto read_rel = std::make_unique(); + auto read_rel = std::make_unique<::substrait::Rel>(); read_rel->set_allocated_read(read.release()); agg->set_allocated_input(read_rel.release()); - auto agg_rel = std::make_unique(); + auto agg_rel = std::make_unique<::substrait::Rel>(); agg_rel->set_allocated_aggregate(agg.release()); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr plan, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::Plan> plan, CreatePlan(std::move(agg_rel), &ext_set)); return Buffer::FromString(plan->SerializeAsString()); } diff --git a/cpp/src/arrow/engine/substrait/type_internal.h b/cpp/src/arrow/engine/substrait/type_internal.h index b162e4dc2b2..33fdf1d0cc3 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.h +++ b/cpp/src/arrow/engine/substrait/type_internal.h @@ -33,24 +33,23 @@ namespace arrow { namespace engine { -namespace substrait = ::substrait; - ARROW_ENGINE_EXPORT -Result, bool>> FromProto(const substrait::Type&, +Result, bool>> FromProto(const ::substrait::Type&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const DataType&, bool nullable, - ExtensionSet*, const ConversionOptions&); +Result> ToProto(const DataType&, bool nullable, + ExtensionSet*, + const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> FromProto(const substrait::NamedStruct&, +Result> FromProto(const ::substrait::NamedStruct&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const Schema&, ExtensionSet*, - const ConversionOptions&); +Result> ToProto(const Schema&, ExtensionSet*, + const ConversionOptions&); inline std::string TimestampTzTimezoneString() { return "UTC"; } From 149380c93e252d54c2c2a59bf23afb5b47b55aad Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 9 Dec 2022 18:21:44 -0800 Subject: [PATCH 5/9] Cleaning up ::substrait that was introduced again during rebase --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 7 - .../engine/substrait/expression_internal.cc | 182 +++++++++--------- .../engine/substrait/expression_internal.h | 19 +- .../arrow/engine/substrait/plan_internal.cc | 26 +-- .../arrow/engine/substrait/plan_internal.h | 8 +- .../engine/substrait/relation_internal.cc | 121 ++++++------ .../engine/substrait/relation_internal.h | 12 +- cpp/src/arrow/engine/substrait/serde.cc | 21 +- .../engine/substrait/test_plan_builder.cc | 101 +++++----- .../arrow/engine/substrait/type_internal.cc | 56 +++--- .../arrow/engine/substrait/type_internal.h | 13 +- 11 files changed, 272 insertions(+), 294 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 6a668be20a0..3eda538fb2e 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -657,13 +657,6 @@ else() "${THIRDPARTY_MIRROR_URL}/snappy-${ARROW_SNAPPY_BUILD_VERSION}.tar.gz") endif() -# Remove these two lines once https://github.com/substrait-io/substrait/pull/342 merges -set(ENV{ARROW_SUBSTRAIT_URL} - "https://github.com/substrait-io/substrait/archive/e59008b6b202f8af06c2266991161b1e45cb056a.tar.gz" -) -set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM - "f64629cb377fcc62c9d3e8fe69fa6a4cf326f34d756e03db84843c5cce8d04cd") - if(DEFINED ENV{ARROW_SUBSTRAIT_URL}) set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}") else() diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 38aa7e799ca..6caddd1cb53 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -83,7 +83,7 @@ Id NormalizeFunctionName(Id id) { } // namespace -Status DecodeArg(const ::substrait::FunctionArgument& arg, int idx, SubstraitCall* call, +Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (!arg.enum_().empty()) { @@ -100,7 +100,7 @@ Status DecodeArg(const ::substrait::FunctionArgument& arg, int idx, SubstraitCal return Status::OK(); } -Status DecodeOption(const ::substrait::FunctionOption& opt, SubstraitCall* call) { +Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) { std::vector prefs; if (opt.preference_size() == 0) { return Status::Invalid("Invalid Substrait plan. The option ", opt.name(), @@ -114,7 +114,7 @@ Status DecodeOption(const ::substrait::FunctionOption& opt, SubstraitCall* call) } Result DecodeScalarFunction( - Id id, const ::substrait::Expression::ScalarFunction& scalar_fn, + Id id, const substrait::Expression::ScalarFunction& scalar_fn, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable, FromProto(scalar_fn.output_type(), ext_set, conversion_options)); @@ -138,23 +138,22 @@ std::string EnumToString(int value, const google::protobuf::EnumDescriptor* desc return value_desc->name(); } -Result FromProto(const ::substrait::AggregateFunction& func, bool is_hash, +Result FromProto(const substrait::AggregateFunction& func, bool is_hash, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - if (func.phase() != - ::substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) { + if (func.phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) { return Status::NotImplemented( "Unsupported aggregation phase '", - EnumToString(func.phase(), *::substrait::AggregationPhase_descriptor()), + EnumToString(func.phase(), *substrait::AggregationPhase_descriptor()), "'. Only INITIAL_TO_RESULT is supported"); } if (func.invocation() != - ::substrait::AggregateFunction::AggregationInvocation:: + substrait::AggregateFunction::AggregationInvocation:: AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL) { return Status::NotImplemented( "Unsupported aggregation invocation '", EnumToString(func.invocation(), - *::substrait::AggregateFunction::AggregationInvocation_descriptor()), + *substrait::AggregateFunction::AggregationInvocation_descriptor()), "'. Only AGGREGATION_INVOCATION_ALL is " "supported"); } @@ -174,17 +173,17 @@ Result FromProto(const ::substrait::AggregateFunction& func, bool return std::move(call); } -Result FromProto(const ::substrait::Expression& expr, +Result FromProto(const substrait::Expression& expr, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { switch (expr.rex_type_case()) { - case ::substrait::Expression::kLiteral: { + case substrait::Expression::kLiteral: { ARROW_ASSIGN_OR_RAISE(auto datum, FromProto(expr.literal(), ext_set, conversion_options)); return compute::literal(std::move(datum)); } - case ::substrait::Expression::kSelection: { + case substrait::Expression::kSelection: { if (!expr.selection().has_direct_reference()) break; std::optional out; @@ -196,7 +195,7 @@ Result FromProto(const ::substrait::Expression& expr, const auto* ref = &expr.selection().direct_reference(); while (ref != nullptr) { switch (ref->reference_type_case()) { - case ::substrait::Expression::ReferenceSegment::kStructField: { + case substrait::Expression::ReferenceSegment::kStructField: { auto index = ref->struct_field().field(); if (!out) { // Root StructField (column selection) @@ -225,11 +224,11 @@ Result FromProto(const ::substrait::Expression& expr, } break; } - case ::substrait::Expression::ReferenceSegment::kListElement: { + case substrait::Expression::ReferenceSegment::kListElement: { if (!out) { // Root ListField (illegal) return Status::Invalid( - "::substrait::ListElement cannot take a Relation as an argument"); + "substrait::ListElement cannot take a Relation as an argument"); } // ListField on top of an arbitrary expression @@ -257,7 +256,7 @@ Result FromProto(const ::substrait::Expression& expr, break; } - case ::substrait::Expression::kIfThen: { + case substrait::Expression::kIfThen: { const auto& if_then = expr.if_then(); if (!if_then.has_else_()) break; if (if_then.ifs_size() == 0) break; @@ -297,7 +296,7 @@ Result FromProto(const ::substrait::Expression& expr, return compute::call("case_when", std::move(args)); } - case ::substrait::Expression::kScalarFunction: { + case substrait::Expression::kScalarFunction: { const auto& scalar_fn = expr.scalar_function(); ARROW_ASSIGN_OR_RAISE(Id function_id, @@ -330,7 +329,7 @@ Result FromProto(const ::substrait::Expression& expr, expr.DebugString()); } -Result FromProto(const ::substrait::Expression::Literal& lit, +Result FromProto(const substrait::Expression::Literal& lit, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (lit.nullable() && @@ -342,43 +341,43 @@ Result FromProto(const ::substrait::Expression::Literal& lit, } switch (lit.literal_type_case()) { - case ::substrait::Expression::Literal::kBoolean: + case substrait::Expression::Literal::kBoolean: return Datum(lit.boolean()); - case ::substrait::Expression::Literal::kI8: + case substrait::Expression::Literal::kI8: return Datum(static_cast(lit.i8())); - case ::substrait::Expression::Literal::kI16: + case substrait::Expression::Literal::kI16: return Datum(static_cast(lit.i16())); - case ::substrait::Expression::Literal::kI32: + case substrait::Expression::Literal::kI32: return Datum(static_cast(lit.i32())); - case ::substrait::Expression::Literal::kI64: + case substrait::Expression::Literal::kI64: return Datum(static_cast(lit.i64())); - case ::substrait::Expression::Literal::kFp32: + case substrait::Expression::Literal::kFp32: return Datum(lit.fp32()); - case ::substrait::Expression::Literal::kFp64: + case substrait::Expression::Literal::kFp64: return Datum(lit.fp64()); - case ::substrait::Expression::Literal::kString: + case substrait::Expression::Literal::kString: return Datum(lit.string()); - case ::substrait::Expression::Literal::kBinary: + case substrait::Expression::Literal::kBinary: return Datum(BinaryScalar(lit.binary())); - case ::substrait::Expression::Literal::kTimestamp: + case substrait::Expression::Literal::kTimestamp: return Datum( TimestampScalar(static_cast(lit.timestamp()), TimeUnit::MICRO)); - case ::substrait::Expression::Literal::kTimestampTz: + case substrait::Expression::Literal::kTimestampTz: return Datum(TimestampScalar(static_cast(lit.timestamp_tz()), TimeUnit::MICRO, TimestampTzTimezoneString())); - case ::substrait::Expression::Literal::kDate: + case substrait::Expression::Literal::kDate: return Datum(Date32Scalar(lit.date())); - case ::substrait::Expression::Literal::kTime: + case substrait::Expression::Literal::kTime: return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO)); - case ::substrait::Expression::Literal::kIntervalYearToMonth: - case ::substrait::Expression::Literal::kIntervalDayToSecond: { + case substrait::Expression::Literal::kIntervalYearToMonth: + case substrait::Expression::Literal::kIntervalDayToSecond: { Int32Builder builder; std::shared_ptr type; if (lit.has_interval_year_to_month()) { @@ -395,23 +394,23 @@ Result FromProto(const ::substrait::Expression::Literal& lit, ExtensionScalar(FixedSizeListScalar(std::move(array)), std::move(type))); } - case ::substrait::Expression::Literal::kUuid: + case substrait::Expression::Literal::kUuid: return Datum(ExtensionScalar(FixedSizeBinaryScalar(lit.uuid()), uuid())); - case ::substrait::Expression::Literal::kFixedChar: + case substrait::Expression::Literal::kFixedChar: return Datum( ExtensionScalar(FixedSizeBinaryScalar(lit.fixed_char()), fixed_char(static_cast(lit.fixed_char().size())))); - case ::substrait::Expression::Literal::kVarChar: + case substrait::Expression::Literal::kVarChar: return Datum( ExtensionScalar(StringScalar(lit.var_char().value()), varchar(static_cast(lit.var_char().length())))); - case ::substrait::Expression::Literal::kFixedBinary: + case substrait::Expression::Literal::kFixedBinary: return Datum(FixedSizeBinaryScalar(lit.fixed_binary())); - case ::substrait::Expression::Literal::kDecimal: { + case substrait::Expression::Literal::kDecimal: { if (lit.decimal().value().size() != sizeof(Decimal128)) { return Status::Invalid("Decimal literal had ", lit.decimal().value().size(), " bytes (expected ", sizeof(Decimal128), ")"); @@ -428,7 +427,7 @@ Result FromProto(const ::substrait::Expression::Literal& lit, return Datum(Decimal128Scalar(value, std::move(type))); } - case ::substrait::Expression::Literal::kStruct: { + case substrait::Expression::Literal::kStruct: { const auto& struct_ = lit.struct_(); ScalarVector fields(struct_.fields_size()); @@ -448,12 +447,12 @@ Result FromProto(const ::substrait::Expression::Literal& lit, return Datum(std::move(scalar)); } - case ::substrait::Expression::Literal::kList: { + case substrait::Expression::Literal::kList: { const auto& list = lit.list(); if (list.values_size() == 0) { return Status::Invalid( - "::substrait::Expression::Literal::List had no values; should have been an " - "::substrait::Expression::Literal::EmptyList"); + "substrait::Expression::Literal::List had no values; should have been an " + "substrait::Expression::Literal::EmptyList"); } std::shared_ptr element_type; @@ -481,12 +480,12 @@ Result FromProto(const ::substrait::Expression::Literal& lit, return Datum(ListScalar(std::move(arr))); } - case ::substrait::Expression::Literal::kMap: { + case substrait::Expression::Literal::kMap: { const auto& map = lit.map(); if (map.key_values_size() == 0) { return Status::Invalid( - "::substrait::Expression::Literal::Map had no values; should have been an " - "::substrait::Expression::Literal::EmptyMap"); + "substrait::Expression::Literal::Map had no values; should have been an " + "substrait::Expression::Literal::EmptyMap"); } std::shared_ptr key_type, value_type; @@ -542,14 +541,14 @@ Result FromProto(const ::substrait::Expression::Literal& lit, return Datum(std::make_shared(std::move(kv_arr))); } - case ::substrait::Expression::Literal::kEmptyList: { + case substrait::Expression::Literal::kEmptyList: { ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.empty_list().type(), ext_set, conversion_options)); ARROW_ASSIGN_OR_RAISE(auto values, MakeEmptyArray(type_nullable.first)); return ListScalar{std::move(values)}; } - case ::substrait::Expression::Literal::kEmptyMap: { + case substrait::Expression::Literal::kEmptyMap: { ARROW_ASSIGN_OR_RAISE( auto key_type_nullable, FromProto(lit.empty_map().key(), ext_set, conversion_options)); @@ -572,7 +571,7 @@ Result FromProto(const ::substrait::Expression::Literal& lit, return MapScalar{std::move(key_values)}; } - case ::substrait::Expression::Literal::kNull: { + case substrait::Expression::Literal::kNull: { ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.null(), ext_set, conversion_options)); if (!type_nullable.second) { @@ -595,10 +594,10 @@ namespace { struct ScalarToProtoImpl { Status Visit(const NullScalar& s) { return NotImplemented(s); } - using Lit = ::substrait::Expression::Literal; + using Lit = substrait::Expression::Literal; template - Status Primitive(void (::substrait::Expression::Literal::*set)(Arg), + Status Primitive(void (substrait::Expression::Literal::*set)(Arg), const PrimitiveScalar& primitive_scalar) { (lit_->*set)(static_cast(primitive_scalar.value)); return Status::OK(); @@ -819,27 +818,27 @@ struct ScalarToProtoImpl { Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); } Status NotImplemented(const Scalar& s) { - return Status::NotImplemented("conversion to ::substrait::Expression::Literal from ", + return Status::NotImplemented("conversion to substrait::Expression::Literal from ", s.ToString()); } Status operator()(const Scalar& scalar) { return VisitScalarInline(scalar, this); } - ::substrait::Expression::Literal* lit_; + substrait::Expression::Literal* lit_; ExtensionSet* ext_set_; const ConversionOptions& conversion_options_; }; } // namespace -Result> ToProto( +Result> ToProto( const Datum& datum, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (!datum.is_scalar()) { return Status::NotImplemented("representing ", datum.ToString(), - " as a ::substrait::Expression::Literal"); + " as a substrait::Expression::Literal"); } - auto out = std::make_unique<::substrait::Expression::Literal>(); + auto out = std::make_unique(); if (datum.scalar()->is_valid) { RETURN_NOT_OK( @@ -854,11 +853,11 @@ Result> ToProto( } static Status AddChildToReferenceSegment( - ::substrait::Expression::ReferenceSegment& segment, - std::unique_ptr<::substrait::Expression::ReferenceSegment>&& child) { + substrait::Expression::ReferenceSegment& segment, + std::unique_ptr&& child) { auto status = Status::Invalid("Attempt to add child to incomplete reference segment"); switch (segment.reference_type_case()) { - case ::substrait::Expression::ReferenceSegment::kMapKey: { + case substrait::Expression::ReferenceSegment::kMapKey: { auto map_key = segment.mutable_map_key(); if (map_key->has_child()) { status = AddChildToReferenceSegment(*map_key->mutable_child(), std::move(child)); @@ -868,7 +867,7 @@ static Status AddChildToReferenceSegment( } break; } - case ::substrait::Expression::ReferenceSegment::kStructField: { + case substrait::Expression::ReferenceSegment::kStructField: { auto struct_field = segment.mutable_struct_field(); if (struct_field->has_child()) { status = @@ -879,7 +878,7 @@ static Status AddChildToReferenceSegment( } break; } - case ::substrait::Expression::ReferenceSegment::kListElement: { + case substrait::Expression::ReferenceSegment::kListElement: { auto list_element = segment.mutable_list_element(); if (list_element->has_child()) { status = @@ -898,9 +897,9 @@ static Status AddChildToReferenceSegment( // Indexes the given Substrait expression or root (if expr is empty) using the given // ReferenceSegment. -static Result> MakeDirectReference( - std::unique_ptr<::substrait::Expression>&& expr, - std::unique_ptr<::substrait::Expression::ReferenceSegment>&& ref_segment) { +static Result> MakeDirectReference( + std::unique_ptr&& expr, + std::unique_ptr&& ref_segment) { // If expr is already a selection expression, add the index to its index stack. if (expr && expr->has_selection() && expr->selection().has_direct_reference()) { auto selection = expr->mutable_selection(); @@ -911,67 +910,67 @@ static Result> MakeDirectReference( } } - auto selection = std::make_unique<::substrait::Expression::FieldReference>(); + auto selection = std::make_unique(); selection->set_allocated_direct_reference(ref_segment.release()); - if (expr && expr->rex_type_case() != ::substrait::Expression::REX_TYPE_NOT_SET) { + if (expr && expr->rex_type_case() != substrait::Expression::REX_TYPE_NOT_SET) { selection->set_allocated_expression(expr.release()); } else { selection->set_allocated_root_reference( - new ::substrait::Expression::FieldReference::RootReference()); + new substrait::Expression::FieldReference::RootReference()); } - auto out = std::make_unique<::substrait::Expression>(); + auto out = std::make_unique(); out->set_allocated_selection(selection.release()); return std::move(out); } // Indexes the given Substrait struct-typed expression or root (if expr is empty) using // the given field index. -static Result> MakeStructFieldReference( - std::unique_ptr<::substrait::Expression>&& expr, int field) { +static Result> MakeStructFieldReference( + std::unique_ptr&& expr, int field) { auto struct_field = - std::make_unique<::substrait::Expression::ReferenceSegment::StructField>(); + std::make_unique(); struct_field->set_field(field); - auto ref_segment = std::make_unique<::substrait::Expression::ReferenceSegment>(); + auto ref_segment = std::make_unique(); ref_segment->set_allocated_struct_field(struct_field.release()); return MakeDirectReference(std::move(expr), std::move(ref_segment)); } // Indexes the given Substrait list-typed expression using the given offset. -static Result> MakeListElementReference( - std::unique_ptr<::substrait::Expression>&& expr, int offset) { +static Result> MakeListElementReference( + std::unique_ptr&& expr, int offset) { auto list_element = - std::make_unique<::substrait::Expression::ReferenceSegment::ListElement>(); + std::make_unique(); list_element->set_offset(offset); - auto ref_segment = std::make_unique<::substrait::Expression::ReferenceSegment>(); + auto ref_segment = std::make_unique(); ref_segment->set_allocated_list_element(list_element.release()); return MakeDirectReference(std::move(expr), std::move(ref_segment)); } -Result> EncodeSubstraitCall( +Result> EncodeSubstraitCall( const SubstraitCall& call, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id())); - auto scalar_fn = std::make_unique<::substrait::Expression::ScalarFunction>(); + auto scalar_fn = std::make_unique(); scalar_fn->set_function_reference(anchor); ARROW_ASSIGN_OR_RAISE( - std::unique_ptr<::substrait::Type> output_type, + std::unique_ptr output_type, ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options)); scalar_fn->set_allocated_output_type(output_type.release()); for (int i = 0; i < call.size(); i++) { - ::substrait::FunctionArgument* arg = scalar_fn->add_arguments(); + substrait::FunctionArgument* arg = scalar_fn->add_arguments(); if (call.HasEnumArg(i)) { ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(i)); arg->set_enum_(std::string(enum_val)); } else if (call.HasValueArg(i)) { ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::Expression> value_expr, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr value_expr, ToProto(value_arg, ext_set, conversion_options)); arg->set_allocated_value(value_expr.release()); } else { @@ -982,14 +981,14 @@ Result> EncodeSubstrait return std::move(scalar_fn); } -Result> ToProto( +Result> ToProto( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (!expr.IsBound()) { return Status::Invalid("ToProto requires a bound Expression"); } - auto out = std::make_unique<::substrait::Expression>(); + auto out = std::make_unique(); if (auto datum = expr.literal()) { ARROW_ASSIGN_OR_RAISE(auto literal, ToProto(*datum, ext_set, conversion_options)); @@ -1014,11 +1013,11 @@ Result> ToProto( auto conditions = call->arguments[0].call(); if (conditions && conditions->function_name == "make_struct") { // catch the special case of calls convertible to IfThen - auto if_then_ = std::make_unique<::substrait::Expression::IfThen>(); + auto if_then_ = std::make_unique(); // don't try to convert argument 0 of the case_when; we have to convert the elements // of make_struct individually - std::vector> arguments( + std::vector> arguments( call->arguments.size() - 1); for (size_t i = 1; i < call->arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i - 1], @@ -1028,7 +1027,7 @@ Result> ToProto( for (size_t i = 0; i < conditions->arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(auto cond_substrait, ToProto(conditions->arguments[i], ext_set, conversion_options)); - auto clause = std::make_unique<::substrait::Expression::IfThen::IfClause>(); + auto clause = std::make_unique(); clause->set_allocated_if_(cond_substrait.release()); clause->set_allocated_then(arguments[i].release()); if_then_->mutable_ifs()->AddAllocated(clause.release()); @@ -1043,7 +1042,7 @@ Result> ToProto( // the remaining function pattern matchers only convert the function itself, so we // should be able to convert all its arguments first here - std::vector> arguments(call->arguments.size()); + std::vector> arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i], ToProto(call->arguments[i], ext_set, conversion_options)); @@ -1069,7 +1068,7 @@ Result> ToProto( if (arguments[0]->has_selection() && arguments[0]->selection().has_direct_reference()) { if (arguments[1]->has_literal() && arguments[1]->literal().literal_type_case() == - ::substrait::Expression::Literal::kI32) { + substrait::Expression::Literal::kI32) { return MakeListElementReference(std::move(arguments[0]), arguments[1]->literal().i32()); } @@ -1078,11 +1077,11 @@ Result> ToProto( if (call->function_name == "if_else") { // catch the special case of calls convertible to IfThen - auto if_clause = std::make_unique<::substrait::Expression::IfThen::IfClause>(); + auto if_clause = std::make_unique(); if_clause->set_allocated_if_(arguments[0].release()); if_clause->set_allocated_then(arguments[1].release()); - auto if_then = std::make_unique<::substrait::Expression::IfThen>(); + auto if_then = std::make_unique(); if_then->mutable_ifs()->AddAllocated(if_clause.release()); if_then->set_allocated_else_(arguments[2].release()); @@ -1095,9 +1094,8 @@ Result> ToProto( ExtensionIdRegistry::ArrowToSubstraitCall converter, ext_set->registry()->GetArrowToSubstraitCall(call->function_name)); ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); - ARROW_ASSIGN_OR_RAISE( - std::unique_ptr<::substrait::Expression::ScalarFunction> scalar_fn, - EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr scalar_fn, + EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); out->set_allocated_scalar_function(scalar_fn.release()); return std::move(out); } diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index 65b3f41ddfc..e947537dd1e 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -32,25 +32,28 @@ namespace arrow { namespace engine { +class SubstraitCall; + ARROW_ENGINE_EXPORT -Result FromProto(const ::substrait::Expression&, const ExtensionSet&, +Result FromProto(const substrait::Expression&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const compute::Expression&, - ExtensionSet*, - const ConversionOptions&); +Result> ToProto(const compute::Expression&, + ExtensionSet*, + const ConversionOptions&); ARROW_ENGINE_EXPORT -Result FromProto(const ::substrait::Expression::Literal&, const ExtensionSet&, +Result FromProto(const substrait::Expression::Literal&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto( - const Datum&, ExtensionSet*, const ConversionOptions&); +Result> ToProto(const Datum&, + ExtensionSet*, + const ConversionOptions&); ARROW_ENGINE_EXPORT -Result FromProto(const ::substrait::AggregateFunction&, bool is_hash, +Result FromProto(const substrait::AggregateFunction&, bool is_hash, const ExtensionSet&, const ConversionOptions&); } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index dfec4fa7336..6d12c19fcd7 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -42,7 +42,7 @@ using internal::checked_cast; namespace engine { -Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* plan) { +Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) { plan->clear_extension_uris(); std::unordered_map map; @@ -53,7 +53,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* pla auto uri = ext_set.uris().at(anchor); if (uri.empty()) continue; - auto ext_uri = std::make_unique<::substrait::extensions::SimpleExtensionURI>(); + auto ext_uri = std::make_unique(); ext_uri->set_uri(std::string(uri)); ext_uri->set_extension_uri_anchor(anchor); uris->AddAllocated(ext_uri.release()); @@ -64,7 +64,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* pla auto extensions = plan->mutable_extensions(); extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); - using ExtDecl = ::substrait::extensions::SimpleExtensionDeclaration; + using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); @@ -96,7 +96,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* pla return Status::OK(); } -Result GetExtensionSetFromPlan(const ::substrait::Plan& plan, +Result GetExtensionSetFromPlan(const substrait::Plan& plan, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { if (registry == NULLPTR) { @@ -114,18 +114,18 @@ Result GetExtensionSetFromPlan(const ::substrait::Plan& plan, std::unordered_map type_ids, function_ids; for (const auto& ext : plan.extensions()) { switch (ext.mapping_type_case()) { - case ::substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { return Status::NotImplemented("Type Variations are not yet implemented"); } - case ::substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { const auto& type = ext.extension_type(); std::string_view uri = uris[type.extension_uri_reference()]; type_ids[type.type_anchor()] = Id{uri, type.name()}; break; } - case ::substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { const auto& fn = ext.extension_function(); std::string_view uri = uris[fn.extension_uri_reference()]; function_ids[fn.function_anchor()] = Id{uri, fn.name()}; @@ -148,8 +148,8 @@ constexpr uint32_t kSubstraitMajorVersion = 0; constexpr uint32_t kSubstraitMinorVersion = 20; constexpr uint32_t kSubstraitPatchVersion = 0; -std::unique_ptr<::substrait::Version> CreateVersion() { - auto version = std::make_unique<::substrait::Version>(); +std::unique_ptr CreateVersion() { + auto version = std::make_unique(); version->set_major_number(kSubstraitMajorVersion); version->set_minor_number(kSubstraitMinorVersion); version->set_patch_number(kSubstraitPatchVersion); @@ -159,13 +159,13 @@ std::unique_ptr<::substrait::Version> CreateVersion() { } // namespace -Result> PlanToProto( +Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto subs_plan = std::make_unique<::substrait::Plan>(); + auto subs_plan = std::make_unique(); subs_plan->set_allocated_version(CreateVersion().release()); - auto plan_rel = std::make_unique<::substrait::PlanRel>(); - auto rel_root = std::make_unique<::substrait::RelRoot>(); + auto plan_rel = std::make_unique(); + auto rel_root = std::make_unique(); ARROW_ASSIGN_OR_RAISE(auto rel, ToProto(declr, ext_set, conversion_options)); rel_root->set_allocated_input(rel.release()); plan_rel->set_allocated_root(rel_root.release()); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 0299352bfc1..235bf1a6ce1 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -41,7 +41,7 @@ namespace engine { /// \param[in,out] plan the Substrait plan message that is to be updated /// \return success or failure ARROW_ENGINE_EXPORT -Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* plan); +Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan); /// \brief Interprets the extension information of a Substrait Plan message into an /// ExtensionSet. @@ -53,10 +53,10 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, ::substrait::Plan* pla /// correspond to Substrait's URI/name pairs ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( - const ::substrait::Plan& plan, const ConversionOptions& conversion_options, + const substrait::Plan& plan, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry = default_extension_id_registry()); -/// \brief Serialize a declaration into a ::substrait::Plan. +/// \brief Serialize a declaration into a substrait::Plan. /// /// Note that, this is a part of a roundtripping test API and not /// designed for use in production @@ -64,7 +64,7 @@ Result GetExtensionSetFromPlan( /// \param[in, out] ext_set the extension set to be updated /// \param[in] conversion_options options to control serialization behavior /// \return the serialized plan -ARROW_ENGINE_EXPORT Result> PlanToProto( +ARROW_ENGINE_EXPORT Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index b9b5016d24c..0faeaec554f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -98,9 +98,9 @@ Result ProcessEmit(const RelMessage& rel, const std::shared_ptr& schema) { if (rel.has_common()) { switch (rel.common().emit_kind_case()) { - case ::substrait::RelCommon::EmitKindCase::kDirect: + case substrait::RelCommon::EmitKindCase::kDirect: return no_emit_declr; - case ::substrait::RelCommon::EmitKindCase::kEmit: { + case substrait::RelCommon::EmitKindCase::kEmit: { ARROW_ASSIGN_OR_RAISE(auto emit_info, GetEmitInfo(rel, schema)); return DeclarationInfo{ compute::Declaration::Sequence( @@ -123,10 +123,10 @@ Status CheckRelCommon(const RelMessage& rel, if (rel.has_common()) { if (rel.common().has_hint() && conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP) { - return Status::NotImplemented("::substrait::RelCommon::Hint"); + return Status::NotImplemented("substrait::RelCommon::Hint"); } if (rel.common().has_advanced_extension()) { - return Status::NotImplemented("::substrait::RelCommon::advanced_extension"); + return Status::NotImplemented("substrait::RelCommon::advanced_extension"); } } if (rel.has_advanced_extension()) { @@ -153,8 +153,7 @@ Status DiscoverFilesFromDir(const std::shared_ptr& local_fs return Status::OK(); } -Result FromProto(const ::substrait::Rel& rel, - const ExtensionSet& ext_set, +Result FromProto(const substrait::Rel& rel, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { static bool dataset_init = false; if (!dataset_init) { @@ -163,7 +162,7 @@ Result FromProto(const ::substrait::Rel& rel, } switch (rel.rel_type_case()) { - case ::substrait::Rel::RelTypeCase::kRead: { + case substrait::Rel::RelTypeCase::kRead: { const auto& read = rel.read(); RETURN_NOT_OK(CheckRelCommon(read, conversion_options)); @@ -180,7 +179,7 @@ Result FromProto(const ::substrait::Rel& rel, } if (read.has_projection()) { - return Status::NotImplemented("::substrait::ReadRel::projection"); + return Status::NotImplemented("substrait::ReadRel::projection"); } if (read.has_named_table()) { @@ -196,7 +195,7 @@ Result FromProto(const ::substrait::Rel& rel, const NamedTableProvider& named_table_provider = conversion_options.named_table_provider; - const ::substrait::ReadRel::NamedTable& named_table = read.named_table(); + const substrait::ReadRel::NamedTable& named_table = read.named_table(); std::vector table_names(named_table.names().begin(), named_table.names().end()); ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, @@ -213,12 +212,12 @@ Result FromProto(const ::substrait::Rel& rel, if (!read.has_local_files()) { return Status::NotImplemented( - "::substrait::ReadRel with read_type other than LocalFiles"); + "substrait::ReadRel with read_type other than LocalFiles"); } if (read.local_files().has_advanced_extension()) { return Status::NotImplemented( - "::substrait::ReadRel::LocalFiles::advanced_extension"); + "substrait::ReadRel::LocalFiles::advanced_extension"); } std::shared_ptr format; @@ -230,31 +229,31 @@ Result FromProto(const ::substrait::Rel& rel, if (item.partition_index() != 0) { return Status::NotImplemented( "non-default " - "::substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); + "substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); } if (item.start() != 0) { return Status::NotImplemented( - "non-default ::substrait::ReadRel::LocalFiles::FileOrFiles::start offset"); + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start offset"); } if (item.length() != 0) { return Status::NotImplemented( - "non-default ::substrait::ReadRel::LocalFiles::FileOrFiles::length"); + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::length"); } // Extract and parse the read relation's source URI ::arrow::internal::Uri item_uri; switch (item.path_type_case()) { - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: RETURN_NOT_OK(item_uri.Parse(item.uri_path())); break; - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: RETURN_NOT_OK(item_uri.Parse(item.uri_file())); break; - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: RETURN_NOT_OK(item_uri.Parse(item.uri_folder())); break; @@ -265,49 +264,49 @@ Result FromProto(const ::substrait::Rel& rel, // Validate the URI before processing if (!item_uri.is_file_scheme()) { - return Status::NotImplemented("::substrait::ReadRel::LocalFiles item (", + return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", item_uri.ToString(), ") does not have file scheme (file:///)"); } if (item_uri.port() != -1) { - return Status::NotImplemented("::substrait::ReadRel::LocalFiles item (", + return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", item_uri.ToString(), ") should not have a port number in path"); } if (!item_uri.query_string().empty()) { - return Status::NotImplemented("::substrait::ReadRel::LocalFiles item (", + return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", item_uri.ToString(), ") should not have a query string in path"); } switch (item.file_format_case()) { - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kParquet: + case substrait::ReadRel::LocalFiles::FileOrFiles::kParquet: format = std::make_shared(); break; - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kArrow: + case substrait::ReadRel::LocalFiles::FileOrFiles::kArrow: format = std::make_shared(); break; default: return Status::NotImplemented( "unsupported file format ", - "(see ::substrait::ReadRel::LocalFiles::FileOrFiles::file_format)"); + "(see substrait::ReadRel::LocalFiles::FileOrFiles::file_format)"); } // Handle the URI as appropriate switch (item.path_type_case()) { - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: { files.emplace_back(item_uri.path(), fs::FileType::File); break; } - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: { RETURN_NOT_OK(DiscoverFilesFromDir(filesystem, item_uri.path(), &files)); break; } - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: { ARROW_ASSIGN_OR_RAISE(auto file_info, filesystem->GetFileInfo(item_uri.path())); @@ -330,7 +329,7 @@ Result FromProto(const ::substrait::Rel& rel, break; } - case ::substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: { ARROW_ASSIGN_OR_RAISE(auto globbed_files, fs::internal::GlobFiles(filesystem, item_uri.path())); std::move(globbed_files.begin(), globbed_files.end(), @@ -358,18 +357,18 @@ Result FromProto(const ::substrait::Rel& rel, std::move(base_schema)); } - case ::substrait::Rel::RelTypeCase::kFilter: { + case substrait::Rel::RelTypeCase::kFilter: { const auto& filter = rel.filter(); RETURN_NOT_OK(CheckRelCommon(filter, conversion_options)); if (!filter.has_input()) { - return Status::Invalid("::substrait::FilterRel with no input relation"); + return Status::Invalid("substrait::FilterRel with no input relation"); } ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set, conversion_options)); if (!filter.has_condition()) { - return Status::Invalid("::substrait::FilterRel with no condition expression"); + return Status::Invalid("substrait::FilterRel with no condition expression"); } ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set, conversion_options)); @@ -384,11 +383,11 @@ Result FromProto(const ::substrait::Rel& rel, input.output_schema); } - case ::substrait::Rel::RelTypeCase::kProject: { + case substrait::Rel::RelTypeCase::kProject: { const auto& project = rel.project(); RETURN_NOT_OK(CheckRelCommon(project, conversion_options)); if (!project.has_input()) { - return Status::Invalid("::substrait::ProjectRel with no input relation"); + return Status::Invalid("substrait::ProjectRel with no input relation"); } ARROW_ASSIGN_OR_RAISE(auto input, FromProto(project.input(), ext_set, conversion_options)); @@ -437,38 +436,38 @@ Result FromProto(const ::substrait::Rel& rel, std::move(project_schema)); } - case ::substrait::Rel::RelTypeCase::kJoin: { + case substrait::Rel::RelTypeCase::kJoin: { const auto& join = rel.join(); RETURN_NOT_OK(CheckRelCommon(join, conversion_options)); if (!join.has_left()) { - return Status::Invalid("::substrait::JoinRel with no left relation"); + return Status::Invalid("substrait::JoinRel with no left relation"); } if (!join.has_right()) { - return Status::Invalid("::substrait::JoinRel with no right relation"); + return Status::Invalid("substrait::JoinRel with no right relation"); } compute::JoinType join_type; switch (join.type()) { - case ::substrait::JoinRel::JOIN_TYPE_UNSPECIFIED: + case substrait::JoinRel::JOIN_TYPE_UNSPECIFIED: return Status::NotImplemented("Unspecified join type is not supported"); - case ::substrait::JoinRel::JOIN_TYPE_INNER: + case substrait::JoinRel::JOIN_TYPE_INNER: join_type = compute::JoinType::INNER; break; - case ::substrait::JoinRel::JOIN_TYPE_OUTER: + case substrait::JoinRel::JOIN_TYPE_OUTER: join_type = compute::JoinType::FULL_OUTER; break; - case ::substrait::JoinRel::JOIN_TYPE_LEFT: + case substrait::JoinRel::JOIN_TYPE_LEFT: join_type = compute::JoinType::LEFT_OUTER; break; - case ::substrait::JoinRel::JOIN_TYPE_RIGHT: + case substrait::JoinRel::JOIN_TYPE_RIGHT: join_type = compute::JoinType::RIGHT_OUTER; break; - case ::substrait::JoinRel::JOIN_TYPE_SEMI: + case substrait::JoinRel::JOIN_TYPE_SEMI: join_type = compute::JoinType::LEFT_SEMI; break; - case ::substrait::JoinRel::JOIN_TYPE_ANTI: + case substrait::JoinRel::JOIN_TYPE_ANTI: join_type = compute::JoinType::LEFT_ANTI; break; default: @@ -481,7 +480,7 @@ Result FromProto(const ::substrait::Rel& rel, FromProto(join.right(), ext_set, conversion_options)); if (!join.has_expression()) { - return Status::Invalid("::substrait::JoinRel with no expression"); + return Status::Invalid("substrait::JoinRel with no expression"); } ARROW_ASSIGN_OR_RAISE(auto expression, @@ -543,12 +542,12 @@ Result FromProto(const ::substrait::Rel& rel, return ProcessEmit(std::move(join), std::move(join_declaration), std::move(join_schema)); } - case ::substrait::Rel::RelTypeCase::kAggregate: { + case substrait::Rel::RelTypeCase::kAggregate: { const auto& aggregate = rel.aggregate(); RETURN_NOT_OK(CheckRelCommon(aggregate, conversion_options)); if (!aggregate.has_input()) { - return Status::Invalid("::substrait::AggregateRel with no input relation"); + return Status::Invalid("substrait::AggregateRel with no input relation"); } ARROW_ASSIGN_OR_RAISE(auto input, @@ -566,7 +565,7 @@ Result FromProto(const ::substrait::Rel& rel, std::vector key_field_ids; std::vector keys; if (aggregate.groupings_size() > 0) { - const ::substrait::AggregateRel::Grouping& group = aggregate.groupings(0); + const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); int grouping_expr_size = group.grouping_expressions_size(); keys.reserve(grouping_expr_size); key_field_ids.reserve(grouping_expr_size); @@ -620,7 +619,7 @@ Result FromProto(const ::substrait::Rel& rel, aggregates.push_back(std::move(arrow_agg)); } else { - return Status::Invalid("::substrait::AggregateFunction not provided"); + return Status::Invalid("substrait::AggregateFunction not provided"); } } FieldVector output_fields; @@ -688,14 +687,14 @@ Result FromProto(const ::substrait::Rel& rel, return ProcessEmit(std::move(set), std::move(set_declaration), std::move(union_schema)); } - case ::substrait::Rel::RelTypeCase::kExtensionLeaf: { + case substrait::Rel::RelTypeCase::kExtensionLeaf: { const auto& ext = rel.extension_leaf(); ARROW_ASSIGN_OR_RAISE( auto ext_leaf_decl, conversion_options.extension_provider->MakeRel({}, ext.detail(), ext_set)); return ProcessEmit(ext, std::move(ext_leaf_decl), ext_leaf_decl.output_schema); } - case ::substrait::Rel::RelTypeCase::kExtensionSingle: { + case substrait::Rel::RelTypeCase::kExtensionSingle: { const auto& ext = rel.extension_single(); ARROW_ASSIGN_OR_RAISE(DeclarationInfo input, FromProto(ext.input(), ext_set, conversion_options)); @@ -704,7 +703,7 @@ Result FromProto(const ::substrait::Rel& rel, conversion_options.extension_provider->MakeRel({input}, ext.detail(), ext_set)); return ProcessEmit(ext, std::move(ext_single_decl), ext_single_decl.output_schema); } - case ::substrait::Rel::RelTypeCase::kExtensionMulti: { + case substrait::Rel::RelTypeCase::kExtensionMulti: { const auto& ext = rel.extension_multi(); std::vector inputs; for (const auto& input : ext.inputs()) { @@ -778,7 +777,7 @@ Result> NamedTableRelationConverter( Result> ScanRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto read_rel = std::make_unique<::substrait::ReadRel>(); + auto read_rel = std::make_unique(); const auto& scan_node_options = checked_cast(*declaration.options); auto dataset = @@ -793,22 +792,22 @@ Result> ScanRelationConverter( read_rel->set_allocated_base_schema(named_struct.release()); // set local files - auto read_rel_lfs = std::make_unique<::substrait::ReadRel::LocalFiles>(); + auto read_rel_lfs = std::make_unique(); for (const auto& file : dataset->files()) { auto read_rel_lfs_ffs = - std::make_unique<::substrait::ReadRel::LocalFiles::FileOrFiles>(); + std::make_unique(); read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); // set file format auto format_type_name = dataset->format()->type_name(); if (format_type_name == "parquet") { read_rel_lfs_ffs->set_allocated_parquet( - new ::substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); + new substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); } else if (format_type_name == "ipc") { read_rel_lfs_ffs->set_allocated_arrow( - new ::substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); + new substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); } else if (format_type_name == "orc") { read_rel_lfs_ffs->set_allocated_orc( - new ::substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); + new substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); } else { return Status::NotImplemented("Unsupported file type: ", format_type_name); } @@ -818,10 +817,10 @@ Result> ScanRelationConverter( return std::move(read_rel); } -Result> FilterRelationConverter( +Result> FilterRelationConverter( const std::shared_ptr& schema, const compute::Declaration& declaration, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto filter_rel = std::make_unique<::substrait::FilterRel>(); + auto filter_rel = std::make_unique(); const auto& filter_node_options = checked_cast(*(declaration.options)); @@ -852,7 +851,7 @@ Result> FilterRelationConverter( Status SerializeAndCombineRelations(const compute::Declaration& declaration, ExtensionSet* ext_set, - std::unique_ptr<::substrait::Rel>* rel, + std::unique_ptr* rel, const ConversionOptions& conversion_options) { const auto& factory_name = declaration.factory_name; ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); @@ -889,10 +888,10 @@ Status SerializeAndCombineRelations(const compute::Declaration& declaration, return Status::OK(); } -Result> ToProto( +Result> ToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto rel = std::make_unique<::substrait::Rel>(); + auto rel = std::make_unique(); RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); return std::move(rel); } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index ab63f5ed7f6..17153f5365f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -32,17 +32,9 @@ namespace arrow { namespace engine { -/// Information resulting from converting a Substrait relation. -struct ARROW_ENGINE_EXPORT DeclarationInfo { - /// The compute declaration produced thus far. - compute::Declaration declaration; - - std::shared_ptr output_schema; -}; - /// \brief Convert a Substrait Rel object to an Acero declaration ARROW_ENGINE_EXPORT -Result FromProto(const ::substrait::Rel&, const ExtensionSet&, +Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); /// \brief Convert an Acero Declaration to a Substrait Rel @@ -51,7 +43,7 @@ Result FromProto(const ::substrait::Rel&, const ExtensionSet&, /// the ExecNode or ExecPlan are not used in this context as Declaration /// is preferred in the Substrait space rather than internal components of /// Acero execution engine. -ARROW_ENGINE_EXPORT Result> ToProto( +ARROW_ENGINE_EXPORT Result> ToProto( const compute::Declaration&, ExtensionSet*, const ConversionOptions&); } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 7c55e9b1ec7..ac5de90326e 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -86,7 +86,7 @@ Result> SerializeRelation( Result DeserializeRelation( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer<::substrait::Rel>(buf)); + ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); ARROW_ASSIGN_OR_RAISE(auto decl_info, FromProto(rel, ext_set, conversion_options)); return std::move(decl_info.declaration); } @@ -135,7 +135,7 @@ Result> DeserializePlans( const Buffer& buf, DeclarationFactory declaration_factory, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<::substrait::Plan>(buf)); + ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); if (plan.version().major_number() < kMinimumMajorVersion && plan.version().minor_number() < kMinimumMinorVersion) { @@ -147,7 +147,7 @@ Result> DeserializePlans( GetExtensionSetFromPlan(plan, conversion_options, registry)); std::vector sink_decls; - for (const ::substrait::PlanRel& plan_rel : plan.relations()) { + for (const substrait::PlanRel& plan_rel : plan.relations()) { ARROW_ASSIGN_OR_RAISE( auto decl_info, FromProto(plan_rel.has_root() ? plan_rel.root().input() : plan_rel.rel(), ext_set, @@ -246,8 +246,7 @@ Result> DeserializePlan( Result> DeserializeSchema( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto named_struct, - ParseFromBuffer<::substrait::NamedStruct>(buf)); + ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer(buf)); return FromProto(named_struct, ext_set, conversion_options); } @@ -262,7 +261,7 @@ Result> SerializeSchema( Result> DeserializeType( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto type, ParseFromBuffer<::substrait::Type>(buf)); + ARROW_ASSIGN_OR_RAISE(auto type, ParseFromBuffer(buf)); ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(type, ext_set, conversion_options)); return std::move(type_nullable.first); } @@ -279,7 +278,7 @@ Result> SerializeType( Result DeserializeExpression( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto expr, ParseFromBuffer<::substrait::Expression>(buf)); + ARROW_ASSIGN_OR_RAISE(auto expr, ParseFromBuffer(buf)); return FromProto(expr, ext_set, conversion_options); } @@ -317,11 +316,11 @@ static Status CheckMessagesEquivalent(const Buffer& l_buf, const Buffer& r_buf) Status CheckMessagesEquivalent(std::string_view message_name, const Buffer& l_buf, const Buffer& r_buf) { if (message_name == "Type") { - return CheckMessagesEquivalent<::substrait::Type>(l_buf, r_buf); + return CheckMessagesEquivalent(l_buf, r_buf); } if (message_name == "NamedStruct") { - return CheckMessagesEquivalent<::substrait::NamedStruct>(l_buf, r_buf); + return CheckMessagesEquivalent(l_buf, r_buf); } if (message_name == "Schema") { @@ -331,11 +330,11 @@ Status CheckMessagesEquivalent(std::string_view message_name, const Buffer& l_bu } if (message_name == "Expression") { - return CheckMessagesEquivalent<::substrait::Expression>(l_buf, r_buf); + return CheckMessagesEquivalent(l_buf, r_buf); } if (message_name == "Rel") { - return CheckMessagesEquivalent<::substrait::Rel>(l_buf, r_buf); + return CheckMessagesEquivalent(l_buf, r_buf); } if (message_name == "Relation") { diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc index 0e28e49b7af..62f4361a610 100644 --- a/cpp/src/arrow/engine/substrait/test_plan_builder.cc +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc @@ -36,59 +36,58 @@ namespace arrow { namespace engine { - namespace internal { static const ConversionOptions kPlanBuilderConversionOptions; -Result> CreateRead(const Table& table, - ExtensionSet* ext_set) { - auto read = std::make_unique<::substrait::ReadRel>(); +Result> CreateRead(const Table& table, + ExtensionSet* ext_set) { + auto read = std::make_unique(); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::NamedStruct> schema, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, ToProto(*table.schema(), ext_set, kPlanBuilderConversionOptions)); read->set_allocated_base_schema(schema.release()); - auto named_table = std::make_unique<::substrait::ReadRel::NamedTable>(); + auto named_table = std::make_unique(); named_table->add_names("test"); read->set_allocated_named_table(named_table.release()); return read; } -void CreateDirectReference(int32_t index, ::substrait::Expression* expr) { - auto reference = std::make_unique<::substrait::Expression::FieldReference>(); - auto reference_segment = std::make_unique<::substrait::Expression::ReferenceSegment>(); +void CreateDirectReference(int32_t index, substrait::Expression* expr) { + auto reference = std::make_unique(); + auto reference_segment = std::make_unique(); auto struct_field = - std::make_unique<::substrait::Expression::ReferenceSegment::StructField>(); + std::make_unique(); struct_field->set_field(index); reference_segment->set_allocated_struct_field(struct_field.release()); reference->set_allocated_direct_reference(reference_segment.release()); auto root_reference = - std::make_unique<::substrait::Expression::FieldReference::RootReference>(); + std::make_unique(); reference->set_allocated_root_reference(root_reference.release()); expr->set_allocated_selection(reference.release()); } -Result> CreateProject( +Result> CreateProject( Id function_id, const std::vector& arguments, const std::unordered_map> options, const std::vector>& arg_types, const DataType& output_type, ExtensionSet* ext_set) { - auto project = std::make_unique<::substrait::ProjectRel>(); + auto project = std::make_unique(); - auto call = std::make_unique<::substrait::Expression::ScalarFunction>(); + auto call = std::make_unique(); ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id)); call->set_function_reference(function_anchor); std::size_t arg_index = 0; std::size_t table_arg_index = 0; for (const std::shared_ptr& arg_type : arg_types) { - ::substrait::FunctionArgument* argument = call->add_arguments(); + substrait::FunctionArgument* argument = call->add_arguments(); if (arg_type) { // If it has a type then it's a reference to the input table - auto expression = std::make_unique<::substrait::Expression>(); + auto expression = std::make_unique(); CreateDirectReference(static_cast(table_arg_index++), expression.get()); argument->set_allocated_value(expression.release()); } else { @@ -99,7 +98,7 @@ Result> CreateProject( arg_index++; } for (const auto& opt : options) { - ::substrait::FunctionOption* option = call->add_options(); + substrait::FunctionOption* option = call->add_options(); option->set_name(opt.first); for (const std::string& pref : opt.second) { option->add_preference(pref); @@ -107,49 +106,49 @@ Result> CreateProject( } ARROW_ASSIGN_OR_RAISE( - std::unique_ptr<::substrait::Type> output_type_substrait, + std::unique_ptr output_type_substrait, ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions)); call->set_allocated_output_type(output_type_substrait.release()); - ::substrait::Expression* call_expression = project->add_expressions(); + substrait::Expression* call_expression = project->add_expressions(); call_expression->set_allocated_scalar_function(call.release()); return project; } -Result> CreateAgg(Id function_id, - const std::vector& keys, - int arg_idx, - const DataType& output_type, - ExtensionSet* ext_set) { - auto agg = std::make_unique<::substrait::AggregateRel>(); +Result> CreateAgg(Id function_id, + const std::vector& keys, + int arg_idx, + const DataType& output_type, + ExtensionSet* ext_set) { + auto agg = std::make_unique(); if (!keys.empty()) { - ::substrait::AggregateRel::Grouping* grouping = agg->add_groupings(); + substrait::AggregateRel::Grouping* grouping = agg->add_groupings(); for (int key : keys) { - ::substrait::Expression* key_expr = grouping->add_grouping_expressions(); + substrait::Expression* key_expr = grouping->add_grouping_expressions(); CreateDirectReference(key, key_expr); } } - ::substrait::AggregateRel::Measure* measure_wrapper = agg->add_measures(); - auto agg_func = std::make_unique<::substrait::AggregateFunction>(); + substrait::AggregateRel::Measure* measure_wrapper = agg->add_measures(); + auto agg_func = std::make_unique(); ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id)); agg_func->set_function_reference(function_anchor); - ::substrait::FunctionArgument* arg = agg_func->add_arguments(); - auto arg_expr = std::make_unique<::substrait::Expression>(); + substrait::FunctionArgument* arg = agg_func->add_arguments(); + auto arg_expr = std::make_unique(); CreateDirectReference(arg_idx, arg_expr.get()); arg->set_allocated_value(arg_expr.release()); - agg_func->set_phase(::substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT); + agg_func->set_phase(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT); agg_func->set_invocation( - ::substrait::AggregateFunction::AggregationInvocation:: + substrait::AggregateFunction::AggregationInvocation:: AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL); ARROW_ASSIGN_OR_RAISE( - std::unique_ptr<::substrait::Type> output_type_substrait, + std::unique_ptr output_type_substrait, ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions)); agg_func->set_allocated_output_type(output_type_substrait.release()); measure_wrapper->set_allocated_measure(agg_func.release()); @@ -157,8 +156,8 @@ Result> CreateAgg(Id function_id, return agg; } -std::unique_ptr<::substrait::Version> CreateTestVersion() { - auto version = std::make_unique<::substrait::Version>(); +std::unique_ptr CreateTestVersion() { + auto version = std::make_unique(); version->set_major_number(std::numeric_limits::max()); version->set_minor_number(std::numeric_limits::max()); version->set_patch_number(std::numeric_limits::max()); @@ -166,13 +165,13 @@ std::unique_ptr<::substrait::Version> CreateTestVersion() { return version; } -Result> CreatePlan( - std::unique_ptr<::substrait::Rel> root, ExtensionSet* ext_set) { - auto plan = std::make_unique<::substrait::Plan>(); +Result> CreatePlan(std::unique_ptr root, + ExtensionSet* ext_set) { + auto plan = std::make_unique(); plan->set_allocated_version(CreateTestVersion().release()); - ::substrait::PlanRel* plan_rel = plan->add_relations(); - auto rel_root = std::make_unique<::substrait::RelRoot>(); + substrait::PlanRel* plan_rel = plan->add_relations(); + auto rel_root = std::make_unique(); rel_root->set_allocated_input(root.release()); plan_rel->set_allocated_root(rel_root.release()); @@ -187,20 +186,20 @@ Result> CreateScanProjectSubstrait( const std::vector>& data_types, const DataType& output_type) { ExtensionSet ext_set; - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::ReadRel> read, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr read, CreateRead(*input_table, &ext_set)); ARROW_ASSIGN_OR_RAISE( - std::unique_ptr<::substrait::ProjectRel> project, + std::unique_ptr project, CreateProject(function_id, arguments, options, data_types, output_type, &ext_set)); - auto read_rel = std::make_unique<::substrait::Rel>(); + auto read_rel = std::make_unique(); read_rel->set_allocated_read(read.release()); project->set_allocated_input(read_rel.release()); - auto project_rel = std::make_unique<::substrait::Rel>(); + auto project_rel = std::make_unique(); project_rel->set_allocated_project(project.release()); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::Plan> plan, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr plan, CreatePlan(std::move(project_rel), &ext_set)); return Buffer::FromString(plan->SerializeAsString()); } @@ -210,19 +209,19 @@ Result> CreateScanAggSubstrait( const std::vector& key_idxs, int arg_idx, const DataType& output_type) { ExtensionSet ext_set; - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::ReadRel> read, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr read, CreateRead(*input_table, &ext_set)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::AggregateRel> agg, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr agg, CreateAgg(function_id, key_idxs, arg_idx, output_type, &ext_set)); - auto read_rel = std::make_unique<::substrait::Rel>(); + auto read_rel = std::make_unique(); read_rel->set_allocated_read(read.release()); agg->set_allocated_input(read_rel.release()); - auto agg_rel = std::make_unique<::substrait::Rel>(); + auto agg_rel = std::make_unique(); agg_rel->set_allocated_aggregate(agg.release()); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr<::substrait::Plan> plan, + ARROW_ASSIGN_OR_RAISE(std::unique_ptr plan, CreatePlan(std::move(agg_rel), &ext_set)); return Buffer::FromString(plan->SerializeAsString()); } diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 01f189d224c..fad49b822b4 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -228,20 +228,18 @@ struct DataTypeToProtoImpl { Status Visit(const NullType& t) { return EncodeUserDefined(t); } Status Visit(const BooleanType& t) { - return SetWith(&::substrait::Type::set_allocated_bool_); + return SetWith(&substrait::Type::set_allocated_bool_); } - Status Visit(const Int8Type& t) { - return SetWith(&::substrait::Type::set_allocated_i8); - } + Status Visit(const Int8Type& t) { return SetWith(&substrait::Type::set_allocated_i8); } Status Visit(const Int16Type& t) { - return SetWith(&::substrait::Type::set_allocated_i16); + return SetWith(&substrait::Type::set_allocated_i16); } Status Visit(const Int32Type& t) { - return SetWith(&::substrait::Type::set_allocated_i32); + return SetWith(&substrait::Type::set_allocated_i32); } Status Visit(const Int64Type& t) { - return SetWith(&::substrait::Type::set_allocated_i64); + return SetWith(&substrait::Type::set_allocated_i64); } Status Visit(const UInt8Type& t) { return EncodeUserDefined(t); } @@ -251,27 +249,26 @@ struct DataTypeToProtoImpl { Status Visit(const HalfFloatType& t) { return EncodeUserDefined(t); } Status Visit(const FloatType& t) { - return SetWith(&::substrait::Type::set_allocated_fp32); + return SetWith(&substrait::Type::set_allocated_fp32); } Status Visit(const DoubleType& t) { - return SetWith(&::substrait::Type::set_allocated_fp64); + return SetWith(&substrait::Type::set_allocated_fp64); } Status Visit(const StringType& t) { - return SetWith(&::substrait::Type::set_allocated_string); + return SetWith(&substrait::Type::set_allocated_string); } Status Visit(const BinaryType& t) { - return SetWith(&::substrait::Type::set_allocated_binary); + return SetWith(&substrait::Type::set_allocated_binary); } Status Visit(const FixedSizeBinaryType& t) { - SetWithThen(&::substrait::Type::set_allocated_fixed_binary) - ->set_length(t.byte_width()); + SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width()); return Status::OK(); } Status Visit(const Date32Type& t) { - return SetWith(&::substrait::Type::set_allocated_date); + return SetWith(&substrait::Type::set_allocated_date); } Status Visit(const Date64Type& t) { return NotImplemented(t); } @@ -279,10 +276,10 @@ struct DataTypeToProtoImpl { if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); if (t.timezone() == "") { - return SetWith(&::substrait::Type::set_allocated_timestamp); + return SetWith(&substrait::Type::set_allocated_timestamp); } if (t.timezone() == TimestampTzTimezoneString()) { - return SetWith(&::substrait::Type::set_allocated_timestamp_tz); + return SetWith(&substrait::Type::set_allocated_timestamp_tz); } return NotImplemented(t); @@ -291,14 +288,14 @@ struct DataTypeToProtoImpl { Status Visit(const Time32Type& t) { return NotImplemented(t); } Status Visit(const Time64Type& t) { if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); - return SetWith(&::substrait::Type::set_allocated_time); + return SetWith(&substrait::Type::set_allocated_time); } Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const Decimal128Type& t) { - auto dec = SetWithThen(&::substrait::Type::set_allocated_decimal); + auto dec = SetWithThen(&substrait::Type::set_allocated_decimal); dec->set_precision(t.precision()); dec->set_scale(t.scale()); return Status::OK(); @@ -309,13 +306,12 @@ struct DataTypeToProtoImpl { // FIXME assert default field name; custom ones won't roundtrip ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*t.value_type(), t.value_field()->nullable(), ext_set_, conversion_options_)); - SetWithThen(&::substrait::Type::set_allocated_list) - ->set_allocated_type(type.release()); + SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release()); return Status::OK(); } Status Visit(const StructType& t) { - auto types = SetWithThen(&::substrait::Type::set_allocated_struct_)->mutable_types(); + auto types = SetWithThen(&substrait::Type::set_allocated_struct_)->mutable_types(); types->Reserve(t.num_fields()); @@ -336,7 +332,7 @@ struct DataTypeToProtoImpl { Status Visit(const MapType& t) { // FIXME assert default field names; custom ones won't roundtrip - auto map = SetWithThen(&::substrait::Type::set_allocated_map); + auto map = SetWithThen(&substrait::Type::set_allocated_map); ARROW_ASSIGN_OR_RAISE(auto key, ToProto(*t.key_type(), /*nullable=*/false, ext_set_, conversion_options_)); @@ -351,25 +347,25 @@ struct DataTypeToProtoImpl { Status Visit(const ExtensionType& t) { if (UnwrapUuid(t)) { - return SetWith(&::substrait::Type::set_allocated_uuid); + return SetWith(&substrait::Type::set_allocated_uuid); } if (auto length = UnwrapFixedChar(t)) { - SetWithThen(&::substrait::Type::set_allocated_fixed_char)->set_length(*length); + SetWithThen(&substrait::Type::set_allocated_fixed_char)->set_length(*length); return Status::OK(); } if (auto length = UnwrapVarChar(t)) { - SetWithThen(&::substrait::Type::set_allocated_varchar)->set_length(*length); + SetWithThen(&substrait::Type::set_allocated_varchar)->set_length(*length); return Status::OK(); } if (UnwrapIntervalYear(t)) { - return SetWith(&::substrait::Type::set_allocated_interval_year); + return SetWith(&substrait::Type::set_allocated_interval_year); } if (UnwrapIntervalDay(t)) { - return SetWith(&::substrait::Type::set_allocated_interval_day); + return SetWith(&substrait::Type::set_allocated_interval_day); } return NotImplemented(t); @@ -383,7 +379,7 @@ struct DataTypeToProtoImpl { Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } template - Sub* SetWithThen(void (::substrait::Type::*set_allocated_sub)(Sub*)) { + Sub* SetWithThen(void (substrait::Type::*set_allocated_sub)(Sub*)) { auto sub = std::make_unique(); sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE : substrait::Type::NULLABILITY_REQUIRED); @@ -394,7 +390,7 @@ struct DataTypeToProtoImpl { } template - Status SetWith(void (::substrait::Type::*set_allocated_sub)(Sub*)) { + Status SetWith(void (substrait::Type::*set_allocated_sub)(Sub*)) { return SetWithThen(set_allocated_sub), Status::OK(); } @@ -480,7 +476,7 @@ Result> ToProto( const ConversionOptions& conversion_options) { if (conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP && schema.metadata() != nullptr) { - return Status::Invalid("::substrait::NamedStruct does not support schema metadata"); + return Status::Invalid("substrait::NamedStruct does not support schema metadata"); } auto named_struct = std::make_unique(); diff --git a/cpp/src/arrow/engine/substrait/type_internal.h b/cpp/src/arrow/engine/substrait/type_internal.h index 33fdf1d0cc3..0d53028f493 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.h +++ b/cpp/src/arrow/engine/substrait/type_internal.h @@ -34,22 +34,21 @@ namespace arrow { namespace engine { ARROW_ENGINE_EXPORT -Result, bool>> FromProto(const ::substrait::Type&, +Result, bool>> FromProto(const substrait::Type&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const DataType&, bool nullable, - ExtensionSet*, - const ConversionOptions&); +Result> ToProto(const DataType&, bool nullable, + ExtensionSet*, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> FromProto(const ::substrait::NamedStruct&, +Result> FromProto(const substrait::NamedStruct&, const ExtensionSet&, const ConversionOptions&); ARROW_ENGINE_EXPORT -Result> ToProto(const Schema&, ExtensionSet*, - const ConversionOptions&); +Result> ToProto(const Schema&, ExtensionSet*, + const ConversionOptions&); inline std::string TimestampTzTimezoneString() { return "UTC"; } From bb30d538aa8285f9f1585c5ec5781995cfcb244e Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 15 Dec 2022 15:36:34 -0800 Subject: [PATCH 6/9] Fix test error introduced by rebase. Add in logic to try and ensure the asof join node is not marked finished from the process thread. Doesn't currently work because executor can be null. --- cpp/src/arrow/engine/substrait/serde_test.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 80d08926799..e4c32c8d13f 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -3692,6 +3692,10 @@ TEST(Substrait, ReadRelWithGlobFiles) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif +<<<<<<< HEAD +======= + compute::ExecContext exec_context; +>>>>>>> 4b4f26a0b (Fix test error introduced by rebase. Add in logic to try and ensure the asof join node is not marked finished from the process thread. Doesn't currently work because executor can be null.) arrow::dataset::internal::Initialize(); auto dummy_schema = @@ -4196,8 +4200,8 @@ TEST(Substrait, PlanWithAsOfJoinExtension) { input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); auto expected_table = TableFromJSON( out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); - CheckRoundTripResult(std::move(out_schema), std::move(expected_table), - *compute::default_exec_context(), buf, {}, conversion_options); + CheckRoundTripResult(std::move(expected_table), *compute::default_exec_context(), buf, + {}, conversion_options); } } // namespace engine From e7cc46bf560d18c4e8a8934095e4a631c3891038 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 22 Dec 2022 09:54:19 -0800 Subject: [PATCH 7/9] Small workaround for a bug related to ARROW-15732. We have to use an executor for asof-join --- cpp/src/arrow/engine/substrait/serde_test.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index e4c32c8d13f..4f5fb571a04 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -4200,8 +4200,11 @@ TEST(Substrait, PlanWithAsOfJoinExtension) { input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); auto expected_table = TableFromJSON( out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); - CheckRoundTripResult(std::move(expected_table), *compute::default_exec_context(), buf, - {}, conversion_options); + // TODO(ARROW-15732) asof join currently requires a threaded exec context but it should + // not (and we can move back to the default exec context) after ARROW-15732 merges + compute::ExecContext exec_ctx(default_memory_pool(), + ::arrow::internal::GetCpuThreadPool()); + CheckRoundTripResult(std::move(expected_table), exec_ctx, buf, {}, conversion_options); } } // namespace engine From 4462a3c87e7cbd9765f47b8a51d0a70d367ee1ad Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Mon, 26 Dec 2022 07:21:33 -0800 Subject: [PATCH 8/9] clang format --- cpp/src/arrow/engine/substrait/serde_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 4f5fb571a04..78524cf1fca 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -4200,8 +4200,8 @@ TEST(Substrait, PlanWithAsOfJoinExtension) { input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); auto expected_table = TableFromJSON( out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); - // TODO(ARROW-15732) asof join currently requires a threaded exec context but it should - // not (and we can move back to the default exec context) after ARROW-15732 merges + // TODO(ARROW-15732) asof join currently requires a threaded exec context but it should + // not (and we can move back to the default exec context) after ARROW-15732 merges compute::ExecContext exec_ctx(default_memory_pool(), ::arrow::internal::GetCpuThreadPool()); CheckRoundTripResult(std::move(expected_table), exec_ctx, buf, {}, conversion_options); From ddaef9dda341775873b0a3400e09fa2e42d58e20 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 30 Dec 2022 11:24:50 -0800 Subject: [PATCH 9/9] Minor fix after rebase --- cpp/src/arrow/engine/substrait/serde_test.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 78524cf1fca..eee2ed868a6 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -3692,10 +3692,6 @@ TEST(Substrait, ReadRelWithGlobFiles) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #endif -<<<<<<< HEAD -======= - compute::ExecContext exec_context; ->>>>>>> 4b4f26a0b (Fix test error introduced by rebase. Add in logic to try and ensure the asof join node is not marked finished from the process thread. Doesn't currently work because executor can be null.) arrow::dataset::internal::Initialize(); auto dummy_schema = @@ -4200,11 +4196,7 @@ TEST(Substrait, PlanWithAsOfJoinExtension) { input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); auto expected_table = TableFromJSON( out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); - // TODO(ARROW-15732) asof join currently requires a threaded exec context but it should - // not (and we can move back to the default exec context) after ARROW-15732 merges - compute::ExecContext exec_ctx(default_memory_pool(), - ::arrow::internal::GetCpuThreadPool()); - CheckRoundTripResult(std::move(expected_table), exec_ctx, buf, {}, conversion_options); + CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); } } // namespace engine