diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index d0c8c600d55..3eda538fb2e 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1738,6 +1738,8 @@ macro(build_substrait) # Note: not all protos in Substrait actually matter to plan # consumption. No need to build the ones we don't need. set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type) + set(ARROW_SUBSTRAIT_PROTOS extension_rels) + set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto") externalproject_add(substrait_ep ${EP_COMMON_OPTIONS} @@ -1789,6 +1791,27 @@ macro(build_substrait) list(APPEND SUBSTRAIT_SOURCES "${SUBSTRAIT_PROTO_GEN}.cc") endforeach() + foreach(ARROW_SUBSTRAIT_PROTO ${ARROW_SUBSTRAIT_PROTOS}) + set(ARROW_SUBSTRAIT_PROTO_GEN + "${SUBSTRAIT_CPP_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.pb") + foreach(EXT h cc) + set_source_files_properties("${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}" + PROPERTIES COMPILE_OPTIONS + "${SUBSTRAIT_SUPPRESSED_FLAGS}" + GENERATED TRUE + SKIP_UNITY_BUILD_INCLUSION TRUE) + list(APPEND SUBSTRAIT_PROTO_GEN_ALL "${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}") + endforeach() + add_custom_command(OUTPUT "${ARROW_SUBSTRAIT_PROTO_GEN}.cc" + "${ARROW_SUBSTRAIT_PROTO_GEN}.h" + COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${SUBSTRAIT_LOCAL_DIR}/proto" + "-I${ARROW_SUBSTRAIT_PROTOS_DIR}" + "--cpp_out=${SUBSTRAIT_CPP_DIR}" + "${ARROW_SUBSTRAIT_PROTOS_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.proto" + DEPENDS ${PROTO_DEPENDS} substrait_ep) + + list(APPEND SUBSTRAIT_SOURCES "${ARROW_SUBSTRAIT_PROTO_GEN}.cc") + endforeach() add_custom_target(substrait_gen ALL DEPENDS ${SUBSTRAIT_PROTO_GEN_ALL}) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto new file mode 100644 index 00000000000..ceed9f3e455 --- /dev/null +++ b/cpp/proto/substrait/extension_rels.proto @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +syntax = "proto3"; + +package arrow.substrait_ext; + +import "substrait/algebra.proto"; + +option csharp_namespace = "Arrow.Substrait"; +option go_package = "github.com/apache/arrow/substrait"; +option java_multiple_files = true; +option java_package = "io.arrow.substrait"; + +// As-Of-Join relation +message AsOfJoinRel { + // One key per input relation, each key describing how to join the corresponding input + repeated AsOfJoinKey keys = 1; + + // As-Of tolerance, in units of the on-key + int64 tolerance = 2; + + // As-Of-Join key + message AsOfJoinKey { + // A field reference defining the on-key + .substrait.Expression on = 1; + + // A set of field references defining the by-key + repeated .substrait.Expression by = 2; + } +} diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index 5890e10c206..366508e34b8 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -101,9 +101,19 @@ static void TableJoinOverhead(benchmark::State& state, benchmark::Counter(static_cast(default_memory_pool()->max_memory())); } +AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key, + std::vector by_key, int64_t tolerance) { + std::vector input_keys(repeat); + for (size_t i = 0; i < repeat; i++) { + input_keys[i] = {on_key, by_key}; + } + return AsofJoinNodeOptions(input_keys, tolerance); +} + static void AsOfJoinOverhead(benchmark::State& state) { int64_t tolerance = 0; - AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance); + AsofJoinNodeOptions options = + GetRepeatedOptions(int(state.range(4)), kTimeCol, {kKeyCol}, tolerance); TableJoinOverhead( state, TableGenerationProperties{int(state.range(0)), int(state.range(1)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 83bbf5df4ca..d071c0ce7f4 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/compute/exec/asof_join_node.h" + #include #include #include @@ -997,17 +999,16 @@ class AsofJoinNode : public ExecNode { } static arrow::Result> MakeOutputSchema( - const std::vector& inputs, + const std::vector> input_schema, const std::vector& indices_of_on_key, const std::vector>& indices_of_by_key) { std::vector> fields; - size_t n_by = indices_of_by_key[0].size(); + size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size(); const DataType* on_key_type = NULLPTR; std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields - for (size_t j = 0; j < inputs.size(); ++j) { - const auto& input_schema = inputs[j]->output_schema(); + for (size_t j = 0; j < input_schema.size(); ++j) { const auto& on_field_ix = indices_of_on_key[j]; const auto& by_field_ix = indices_of_by_key[j]; @@ -1015,10 +1016,10 @@ class AsofJoinNode : public ExecNode { return Status::Invalid("Missing join key on table ", j); } - const auto& on_field = input_schema->fields()[on_field_ix]; + const auto& on_field = input_schema[j]->fields()[on_field_ix]; std::vector by_field(n_by); for (size_t k = 0; k < n_by; k++) { - by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); } if (on_key_type == NULLPTR) { @@ -1038,8 +1039,8 @@ class AsofJoinNode : public ExecNode { } } - for (int i = 0; i < input_schema->num_fields(); ++i) { - const auto field = input_schema->field(i); + for (int i = 0; i < input_schema[j]->num_fields(); ++i) { + const auto field = input_schema[j]->field(i); if (i == on_field_ix) { ARROW_RETURN_NOT_OK(is_valid_on_field(field)); // Only add on field from the left table @@ -1076,6 +1077,56 @@ class AsofJoinNode : public ExecNode { return match.indices()[0]; } + static Result GetByKeySize( + const std::vector& input_keys) { + size_t n_by = 0; + for (size_t i = 0; i < input_keys.size(); ++i) { + const auto& by_key = input_keys[i].by_key; + if (i == 0) { + n_by = by_key.size(); + } else if (n_by != by_key.size()) { + return Status::Invalid("inconsistent size of by-key across inputs"); + } + } + return n_by; + } + + static Result> GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + size_t n_input = input_schema.size(); + std::vector indices_of_on_key(n_input); + for (size_t i = 0; i < n_input; ++i) { + const auto& on_key = input_keys[i].on_key; + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(*input_schema[i], on_key, "on")); + } + return indices_of_on_key; + } + + static Result>> GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); + size_t n_input = input_schema.size(); + std::vector> indices_of_by_key( + n_input, std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + for (size_t k = 0; k < n_by; k++) { + const auto& by_key = input_keys[i].by_key; + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(*input_schema[i], by_key[k], "by")); + } + } + return indices_of_by_key; + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; @@ -1086,24 +1137,21 @@ class AsofJoinNode : public ExecNode { join_options.tolerance); } - size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); + size_t n_input = inputs.size(); std::vector input_labels(n_input); - std::vector indices_of_on_key(n_input); - std::vector> indices_of_by_key( - n_input, std::vector(n_by)); + std::vector> input_schema(n_input); for (size_t i = 0; i < n_input; ++i) { input_labels[i] = i == 0 ? "left" : "right_" + ToChars(i); - const Schema& input_schema = *inputs[i]->output_schema(); - ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], - FindColIndex(input_schema, join_options.on_key, "on")); - for (size_t k = 0; k < n_by; k++) { - ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], - FindColIndex(input_schema, join_options.by_key[k], "by")); - } + input_schema[i] = inputs[i]->output_schema(); } - - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + GetIndicesOfOnKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + GetIndicesOfByKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr output_schema, + MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key)); std::vector> key_hashers; for (size_t i = 0; i < n_input; i++) { @@ -1213,5 +1261,20 @@ void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { } } // namespace internal +namespace asofjoin { + +Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys) { + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + AsofJoinNode::GetIndicesOfOnKey(input_schema, input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + AsofJoinNode::GetIndicesOfByKey(input_schema, input_keys)); + return AsofJoinNode::MakeOutputSchema(input_schema, indices_of_on_key, + indices_of_by_key); +} + +} // namespace asofjoin + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node.h b/cpp/src/arrow/compute/exec/asof_join_node.h new file mode 100644 index 00000000000..27777090d3d --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/options.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { +namespace asofjoin { + +using AsofJoinKeys = AsofJoinNodeOptions::Keys; + +ARROW_EXPORT Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys); + +} // namespace asofjoin +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index c865f9f38f8..e30e8420956 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -142,6 +142,15 @@ void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { ASSERT_OK(builder.Finish(&empty)); } +AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key, + std::vector by_key, int64_t tolerance) { + std::vector input_keys(repeat); + for (size_t i = 0; i < repeat; i++) { + input_keys[i] = {on_key, by_key}; + } + return AsofJoinNodeOptions(input_keys, tolerance); +} + // mutates by copying from_key into to_key and changing from_key to zero Result MutateByKey(BatchesWithSchema& batches, std::string from_key, std::string to_key, bool replace_key = false, @@ -248,7 +257,7 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ const FieldRef time, by_key_type key, const int64_t tolerance) { \ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ - AsofJoinNodeOptions(time, {key}, tolerance)); \ + GetRepeatedOptions(3, time, {key}, tolerance)); \ } EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) @@ -300,7 +309,7 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, int64_t tolerance, const std::string& expected_error_str) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {"key"}, tolerance), + GetRepeatedOptions(2, "time", {"key"}, tolerance), expected_error_str); } @@ -323,27 +332,27 @@ void DoRunMissingKeysTest(const std::shared_ptr& l_schema, void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("invalid_time", {"key"}, 0), + GetRepeatedOptions(2, "invalid_time", {"key"}, 0), "Bad join key on table : No match"); } void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {"invalid_key"}, 0), + GetRepeatedOptions(2, "time", {"invalid_key"}, 0), "Bad join key on table : No match"); } void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0), + DoRunInvalidPlanTest(l_schema, r_schema, GetRepeatedOptions(2, {0, "time"}, {"key"}, 0), "Bad join key on table : No match"); } void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0), + GetRepeatedOptions(2, "time", {FieldRef{0, 1}}, 0), "Bad join key on table : No match"); } @@ -404,7 +413,7 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema, - AsofJoinNodeOptions("time", {"key"}, 1000), + GetRepeatedOptions(2, "time", {"key"}, 1000), "out-of-order on-key values"); } @@ -501,7 +510,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(exp_nokey_batches, MutateByKey(exp_nokey_batches, "key", "key2", true, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, - AsofJoinNodeOptions("time", {"key2"}, tolerance)); + GetRepeatedOptions(3, "time", {"key2"}, tolerance)); }); } static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } @@ -514,7 +523,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key", false, false, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches, - AsofJoinNodeOptions("time", {}, tolerance)); + GetRepeatedOptions(3, "time", {}, tolerance)); }); } static void DoMutateEmptyKey(BasicTest& basic_tests) { diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 5daaf0584ae..325e8e514d1 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -479,22 +479,35 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} - - /// \brief "on" key for the join. + /// \brief Keys for one input table of the AsofJoin operation /// - /// All inputs tables must be sorted by the "on" key. Must be a single field of a common - /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff - /// left_on - tolerance <= right_on <= left_on. - /// Currently, the "on" key must be of an integer, date, or timestamp type. - FieldRef on_key; - /// \brief "by" key for the join. + /// The keys must be consistent across the input tables: + /// Each "on" key must refer to a field of the same type and units across the tables. + /// Each "by" key must refer to a list of fields of the same types across the tables. + struct Keys { + /// \brief "on" key for the join. + /// + /// The input table must be sorted by the "on" key. Must be a single field of a common + /// type. Inexact match is used on the "on" key. i.e., a row is considered a match iff + /// left_on - tolerance <= right_on <= left_on. + /// Currently, the "on" key must be of an integer, date, or timestamp type. + FieldRef on_key; + /// \brief "by" key for the join. + /// + /// Each input table must have each field of the "by" key. Exact equality is used for + /// each field of the "by" key. + /// Currently, each field of the "by" key must be of an integer, date, timestamp, or + /// base-binary type. + std::vector by_key; + }; + + AsofJoinNodeOptions(std::vector input_keys, int64_t tolerance) + : input_keys(std::move(input_keys)), tolerance(tolerance) {} + + /// \brief AsofJoin keys per input table. /// - /// All input tables must have the "by" key. Exact equality - /// is used for the "by" key. - /// Currently, the "by" key must be of an integer, date, timestamp, or base-binary type - std::vector by_key; + /// \see `Keys` for details. + std::vector input_keys; /// \brief Tolerance for inexact "on" key matching. Must be non-negative. /// /// The tolerance is interpreted in the same units as the "on" key. diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index f129394325e..4e5f8bb96b7 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -23,6 +23,7 @@ set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc substrait/extension_set.cc substrait/extension_types.cc + substrait/options.cc substrait/plan_internal.cc substrait/relation_internal.cc substrait/serde.cc diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 415cd195bf6..6caddd1cb53 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -129,6 +129,15 @@ Result DecodeScalarFunction( return std::move(call); } +std::string EnumToString(int value, const google::protobuf::EnumDescriptor* descriptor) { + const google::protobuf::EnumValueDescriptor* value_desc = + descriptor->FindValueByNumber(value); + if (value_desc == nullptr) { + return "unknown"; + } + return value_desc->name(); +} + Result FromProto(const substrait::AggregateFunction& func, bool is_hash, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc new file mode 100644 index 00000000000..9dfd4d7856a --- /dev/null +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include + +#include "arrow/engine/substrait/options.h" + +#include +#include "arrow/compute/exec/asof_join_node.h" +#include "arrow/compute/exec/options.h" +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/options_internal.h" +#include "arrow/engine/substrait/relation_internal.h" +#include "substrait/extension_rels.pb.h" + +namespace arrow { +namespace engine { + +class DefaultExtensionProvider : public ExtensionProvider { + public: + Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) override { + if (rel.Is()) { + arrow::substrait_ext::AsOfJoinRel as_of_join_rel; + rel.UnpackTo(&as_of_join_rel); + return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set); + } + return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ", + rel.DebugString()); + } + + private: + Result MakeAsOfJoinRel( + const std::vector& inputs, + const arrow::substrait_ext::AsOfJoinRel& as_of_join_rel, + const ExtensionSet& ext_set) { + if (inputs.size() < 2) { + return Status::Invalid("substrait_ext::AsOfJoinNode too few input tables: ", + inputs.size()); + } + if (static_cast(as_of_join_rel.keys_size()) != inputs.size()) { + return Status::Invalid("substrait_ext::AsOfJoinNode mismatched number of inputs"); + } + + size_t n_input = inputs.size(), i = 0; + std::vector input_keys(n_input); + for (const auto& keys : as_of_join_rel.keys()) { + // on-key + if (!keys.has_on()) { + return Status::Invalid("substrait_ext::AsOfJoinNode missing on-key for input ", + i); + } + ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(keys.on(), ext_set, {})); + if (on_key_expr.field_ref() == NULLPTR) { + return Status::NotImplemented( + "substrait_ext::AsOfJoinNode non-field-ref on-key for input ", i); + } + const FieldRef& on_key = *on_key_expr.field_ref(); + + // by-key + std::vector by_key; + for (const auto& by_item : keys.by()) { + ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {})); + if (by_key_expr.field_ref() == NULLPTR) { + return Status::NotImplemented( + "substrait_ext::AsOfJoinNode non-field-ref by-key for input ", i); + } + by_key.push_back(*by_key_expr.field_ref()); + } + + input_keys[i] = {std::move(on_key), std::move(by_key)}; + ++i; + } + + // schema + int64_t tolerance = as_of_join_rel.tolerance(); + std::vector> input_schema(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_schema[i] = inputs[i].output_schema; + } + ARROW_ASSIGN_OR_RAISE(auto schema, + compute::asofjoin::MakeOutputSchema(input_schema, input_keys)); + compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance}; + + // declaration + std::vector input_decls(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_decls[i] = inputs[i].declaration; + } + return DeclarationInfo{ + compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)), + std::move(schema)}; + } +}; + +std::shared_ptr ExtensionProvider::kDefaultExtensionProvider = + std::make_shared(); + +std::shared_ptr default_extension_provider() { + return ExtensionProvider::kDefaultExtensionProvider; +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 014842f4d8f..41d0792a8d2 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -24,6 +24,8 @@ #include #include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" namespace arrow { @@ -32,7 +34,7 @@ namespace engine { /// How strictly to adhere to the input structure when converting between Substrait and /// Acero representations of a plan. This allows the user to trade conversion accuracy /// for performance and lenience. -enum class ConversionStrictness { +enum class ARROW_ENGINE_EXPORT ConversionStrictness { /// When a primitive is used at the input that doesn't have an exact match at the /// output, reject the conversion. This effectively asserts that there is no (known) /// information loss in the conversion, and that plans should either round-trip back and @@ -65,9 +67,13 @@ using NamedTableProvider = std::function(const std::vector&)>; static NamedTableProvider kDefaultNamedTableProvider; +class ExtensionProvider; + +ARROW_ENGINE_EXPORT std::shared_ptr default_extension_provider(); + /// Options that control the conversion between Substrait and Acero representations of a /// plan. -struct ConversionOptions { +struct ARROW_ENGINE_EXPORT ConversionOptions { /// \brief How strictly the converter should adhere to the structure of the input. ConversionStrictness strictness = ConversionStrictness::BEST_EFFORT; /// \brief A custom strategy to be used for providing named tables @@ -75,6 +81,7 @@ struct ConversionOptions { /// The default behavior will return an invalid status if the plan has any /// named table relations. NamedTableProvider named_table_provider = kDefaultNamedTableProvider; + std::shared_ptr extension_provider = default_extension_provider(); }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/options_internal.h b/cpp/src/arrow/engine/substrait/options_internal.h new file mode 100644 index 00000000000..0d186147a9a --- /dev/null +++ b/cpp/src/arrow/engine/substrait/options_internal.h @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/visibility.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace engine { + +class ARROW_ENGINE_EXPORT ExtensionProvider { + public: + static std::shared_ptr kDefaultExtensionProvider; + virtual ~ExtensionProvider() = default; + virtual Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) = 0; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index fff0f7563cc..6d12c19fcd7 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -143,7 +143,7 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, namespace { -// FIXME Is there some way to get these from the cmake files? +// TODO(ARROW-18145) Populate these from cmake files constexpr uint32_t kSubstraitMajorVersion = 0; constexpr uint32_t kSubstraitMinorVersion = 20; constexpr uint32_t kSubstraitPatchVersion = 0; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 744227a9339..0faeaec554f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -44,6 +44,7 @@ #include "arrow/engine/substrait/expression_internal.h" #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/options.h" +#include "arrow/engine/substrait/options_internal.h" #include "arrow/engine/substrait/relation.h" #include "arrow/engine/substrait/type_internal.h" #include "arrow/engine/substrait/util.h" @@ -227,7 +228,8 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& // Validate properties of the `FileOrFiles` item if (item.partition_index() != 0) { return Status::NotImplemented( - "non-default substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); + "non-default " + "substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); } if (item.start() != 0) { @@ -685,6 +687,35 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return ProcessEmit(std::move(set), std::move(set_declaration), std::move(union_schema)); } + case substrait::Rel::RelTypeCase::kExtensionLeaf: { + const auto& ext = rel.extension_leaf(); + ARROW_ASSIGN_OR_RAISE( + auto ext_leaf_decl, + conversion_options.extension_provider->MakeRel({}, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_leaf_decl), ext_leaf_decl.output_schema); + } + case substrait::Rel::RelTypeCase::kExtensionSingle: { + const auto& ext = rel.extension_single(); + ARROW_ASSIGN_OR_RAISE(DeclarationInfo input, + FromProto(ext.input(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + auto ext_single_decl, + conversion_options.extension_provider->MakeRel({input}, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_single_decl), ext_single_decl.output_schema); + } + case substrait::Rel::RelTypeCase::kExtensionMulti: { + const auto& ext = rel.extension_multi(); + std::vector inputs; + for (const auto& input : ext.inputs()) { + ARROW_ASSIGN_OR_RAISE(auto input_info, + FromProto(input, ext_set, conversion_options)); + inputs.push_back(std::move(input_info)); + } + ARROW_ASSIGN_OR_RAISE( + auto ext_multi_decl, + conversion_options.extension_provider->MakeRel(inputs, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_multi_decl), ext_multi_decl.output_schema); + } default: break; diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 435006a4e03..eee2ed868a6 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -32,6 +32,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" +#include "arrow/compute/exec/asof_join_node.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/expression_internal.h" @@ -3351,6 +3352,16 @@ TEST(Substrait, IsthmusPlan) { /*include_columns=*/{}, conversion_options); } +NamedTableProvider ProvideMadeTable( + std::function>(const std::vector&)> make) { + return [make](const std::vector& names) -> Result { + ARROW_ASSIGN_OR_RAISE(auto table, make(names)); + std::shared_ptr options = + std::make_shared(table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; +} + TEST(Substrait, ProjectWithMultiFieldExpressions) { auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -4000,5 +4011,193 @@ TEST(Substrait, SetRelationBasic) { &sort_options); } +TEST(Substrait, PlanWithAsOfJoinExtension) { + // This demos an extension relation + std::string substrait_json = R"({ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "extension_multi": { + "common": { + "emit": { + "outputMapping": [0, 1, 2, 3] + } + }, + "inputs": [ + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value1"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T1"] + } + } + }, + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value2"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T2"] + } + } + } + ], + "detail": { + "@type": "/arrow.substrait_ext.AsOfJoinRel", + "keys" : [ + { + "on": { + "selection": { + "directReference": { + "structField": { + "field": 0, + } + }, + "rootReference": {} + } + }, + "by": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1, + } + }, + "rootReference": {} + } + } + ] + }, + { + "on": { + "selection": { + "directReference": { + "structField": { + "field": 0, + } + }, + "rootReference": {} + } + }, + "by": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1, + } + }, + "rootReference": {} + } + } + ] + } + ], + "tolerance": 1000 + } + } + }, + "names": ["time", "key", "value1", "value2"] + } + }], + "expectedTypeUrls": [] + })"; + + std::vector> input_schema = { + schema({field("time", int32()), field("key", int32()), field("value1", float64())}), + schema( + {field("time", int32()), field("key", int32()), field("value2", float64())})}; + NamedTableProvider table_provider = ProvideMadeTable( + [&input_schema]( + const std::vector& names) -> Result> { + if (names.size() != 1) { + return Status::Invalid("Multiple test table names"); + } + if (names[0] == "T1") { + return TableFromJSON(input_schema[0], + {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"}); + } + if (names[0] == "T2") { + return TableFromJSON(input_schema[1], + {"[[1, 1, 1.2], [3, 2, 2.2], [5, 2, 3.2]]"}); + } + return Status::Invalid("Unknown test table name ", names[0]); + }); + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + ASSERT_OK_AND_ASSIGN( + auto out_schema, + compute::asofjoin::MakeOutputSchema( + input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); + auto expected_table = TableFromJSON( + out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); + CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_fwd.h b/cpp/src/arrow/engine/substrait/type_fwd.h index 235d9e82d1b..6089d3f747a 100644 --- a/cpp/src/arrow/engine/substrait/type_fwd.h +++ b/cpp/src/arrow/engine/substrait/type_fwd.h @@ -26,6 +26,7 @@ class ExtensionIdRegistry; class ExtensionSet; struct ConversionOptions; +struct DeclarationInfo; } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index f56aa19a040..fad49b822b4 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -44,7 +44,7 @@ namespace { template bool IsNullable(const TypeMessage& type) { // FIXME what can we do with NULLABILITY_UNSPECIFIED - return type.nullability() != ::substrait::Type::NULLABILITY_REQUIRED; + return type.nullability() != substrait::Type::NULLABILITY_REQUIRED; } template @@ -95,67 +95,67 @@ Result FieldsFromProto(int size, const Types& types, } // namespace Result, bool>> FromProto( - const ::substrait::Type& type, const ExtensionSet& ext_set, + const substrait::Type& type, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { switch (type.kind_case()) { - case ::substrait::Type::kBool: + case substrait::Type::kBool: return FromProtoImpl(type.bool_()); - case ::substrait::Type::kI8: + case substrait::Type::kI8: return FromProtoImpl(type.i8()); - case ::substrait::Type::kI16: + case substrait::Type::kI16: return FromProtoImpl(type.i16()); - case ::substrait::Type::kI32: + case substrait::Type::kI32: return FromProtoImpl(type.i32()); - case ::substrait::Type::kI64: + case substrait::Type::kI64: return FromProtoImpl(type.i64()); - case ::substrait::Type::kFp32: + case substrait::Type::kFp32: return FromProtoImpl(type.fp32()); - case ::substrait::Type::kFp64: + case substrait::Type::kFp64: return FromProtoImpl(type.fp64()); - case ::substrait::Type::kString: + case substrait::Type::kString: return FromProtoImpl(type.string()); - case ::substrait::Type::kBinary: + case substrait::Type::kBinary: return FromProtoImpl(type.binary()); - case ::substrait::Type::kTimestamp: + case substrait::Type::kTimestamp: return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); - case ::substrait::Type::kTimestampTz: + case substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, TimestampTzTimezoneString()); - case ::substrait::Type::kDate: + case substrait::Type::kDate: return FromProtoImpl(type.date()); - case ::substrait::Type::kTime: + case substrait::Type::kTime: return FromProtoImpl(type.time(), TimeUnit::MICRO); - case ::substrait::Type::kIntervalYear: + case substrait::Type::kIntervalYear: return FromProtoImpl(type.interval_year(), interval_year); - case ::substrait::Type::kIntervalDay: + case substrait::Type::kIntervalDay: return FromProtoImpl(type.interval_day(), interval_day); - case ::substrait::Type::kUuid: + case substrait::Type::kUuid: return FromProtoImpl(type.uuid(), uuid); - case ::substrait::Type::kFixedChar: + case substrait::Type::kFixedChar: return FromProtoImpl(type.fixed_char(), fixed_char, type.fixed_char().length()); - case ::substrait::Type::kVarchar: + case substrait::Type::kVarchar: return FromProtoImpl(type.varchar(), varchar, type.varchar().length()); - case ::substrait::Type::kFixedBinary: + case substrait::Type::kFixedBinary: return FromProtoImpl(type.fixed_binary(), type.fixed_binary().length()); - case ::substrait::Type::kDecimal: { + case substrait::Type::kDecimal: { const auto& decimal = type.decimal(); return FromProtoImpl(decimal, decimal.precision(), decimal.scale()); } - case ::substrait::Type::kStruct: { + case substrait::Type::kStruct: { const auto& struct_ = type.struct_(); ARROW_ASSIGN_OR_RAISE( @@ -166,7 +166,7 @@ Result, bool>> FromProto( return FromProtoImpl(struct_, std::move(fields)); } - case ::substrait::Type::kList: { + case substrait::Type::kList: { const auto& list = type.list(); if (!list.has_type()) { @@ -181,7 +181,7 @@ Result, bool>> FromProto( list, field("item", std::move(type_nullable.first), type_nullable.second)); } - case ::substrait::Type::kMap: { + case substrait::Type::kMap: { const auto& map = type.map(); static const std::array kMissing = {"key and value", "value", "key", @@ -207,7 +207,7 @@ Result, bool>> FromProto( field("value", std::move(value_nullable.first), value_nullable.second)); } - case ::substrait::Type::kUserDefined: { + case substrait::Type::kUserDefined: { const auto& user_defined = type.user_defined(); uint32_t anchor = user_defined.type_reference(); ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); @@ -228,20 +228,18 @@ struct DataTypeToProtoImpl { Status Visit(const NullType& t) { return EncodeUserDefined(t); } Status Visit(const BooleanType& t) { - return SetWith(&::substrait::Type::set_allocated_bool_); + return SetWith(&substrait::Type::set_allocated_bool_); } - Status Visit(const Int8Type& t) { - return SetWith(&::substrait::Type::set_allocated_i8); - } + Status Visit(const Int8Type& t) { return SetWith(&substrait::Type::set_allocated_i8); } Status Visit(const Int16Type& t) { - return SetWith(&::substrait::Type::set_allocated_i16); + return SetWith(&substrait::Type::set_allocated_i16); } Status Visit(const Int32Type& t) { - return SetWith(&::substrait::Type::set_allocated_i32); + return SetWith(&substrait::Type::set_allocated_i32); } Status Visit(const Int64Type& t) { - return SetWith(&::substrait::Type::set_allocated_i64); + return SetWith(&substrait::Type::set_allocated_i64); } Status Visit(const UInt8Type& t) { return EncodeUserDefined(t); } @@ -251,27 +249,26 @@ struct DataTypeToProtoImpl { Status Visit(const HalfFloatType& t) { return EncodeUserDefined(t); } Status Visit(const FloatType& t) { - return SetWith(&::substrait::Type::set_allocated_fp32); + return SetWith(&substrait::Type::set_allocated_fp32); } Status Visit(const DoubleType& t) { - return SetWith(&::substrait::Type::set_allocated_fp64); + return SetWith(&substrait::Type::set_allocated_fp64); } Status Visit(const StringType& t) { - return SetWith(&::substrait::Type::set_allocated_string); + return SetWith(&substrait::Type::set_allocated_string); } Status Visit(const BinaryType& t) { - return SetWith(&::substrait::Type::set_allocated_binary); + return SetWith(&substrait::Type::set_allocated_binary); } Status Visit(const FixedSizeBinaryType& t) { - SetWithThen(&::substrait::Type::set_allocated_fixed_binary) - ->set_length(t.byte_width()); + SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width()); return Status::OK(); } Status Visit(const Date32Type& t) { - return SetWith(&::substrait::Type::set_allocated_date); + return SetWith(&substrait::Type::set_allocated_date); } Status Visit(const Date64Type& t) { return NotImplemented(t); } @@ -279,10 +276,10 @@ struct DataTypeToProtoImpl { if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); if (t.timezone() == "") { - return SetWith(&::substrait::Type::set_allocated_timestamp); + return SetWith(&substrait::Type::set_allocated_timestamp); } if (t.timezone() == TimestampTzTimezoneString()) { - return SetWith(&::substrait::Type::set_allocated_timestamp_tz); + return SetWith(&substrait::Type::set_allocated_timestamp_tz); } return NotImplemented(t); @@ -291,14 +288,14 @@ struct DataTypeToProtoImpl { Status Visit(const Time32Type& t) { return NotImplemented(t); } Status Visit(const Time64Type& t) { if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); - return SetWith(&::substrait::Type::set_allocated_time); + return SetWith(&substrait::Type::set_allocated_time); } Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const Decimal128Type& t) { - auto dec = SetWithThen(&::substrait::Type::set_allocated_decimal); + auto dec = SetWithThen(&substrait::Type::set_allocated_decimal); dec->set_precision(t.precision()); dec->set_scale(t.scale()); return Status::OK(); @@ -309,20 +306,18 @@ struct DataTypeToProtoImpl { // FIXME assert default field name; custom ones won't roundtrip ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*t.value_type(), t.value_field()->nullable(), ext_set_, conversion_options_)); - SetWithThen(&::substrait::Type::set_allocated_list) - ->set_allocated_type(type.release()); + SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release()); return Status::OK(); } Status Visit(const StructType& t) { - auto types = SetWithThen(&::substrait::Type::set_allocated_struct_)->mutable_types(); + auto types = SetWithThen(&substrait::Type::set_allocated_struct_)->mutable_types(); types->Reserve(t.num_fields()); for (const auto& field : t.fields()) { if (field->metadata() != nullptr) { - return Status::Invalid( - "::substrait::Type::Struct does not support field metadata"); + return Status::Invalid("substrait::Type::Struct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set_, conversion_options_)); @@ -337,7 +332,7 @@ struct DataTypeToProtoImpl { Status Visit(const MapType& t) { // FIXME assert default field names; custom ones won't roundtrip - auto map = SetWithThen(&::substrait::Type::set_allocated_map); + auto map = SetWithThen(&substrait::Type::set_allocated_map); ARROW_ASSIGN_OR_RAISE(auto key, ToProto(*t.key_type(), /*nullable=*/false, ext_set_, conversion_options_)); @@ -352,25 +347,25 @@ struct DataTypeToProtoImpl { Status Visit(const ExtensionType& t) { if (UnwrapUuid(t)) { - return SetWith(&::substrait::Type::set_allocated_uuid); + return SetWith(&substrait::Type::set_allocated_uuid); } if (auto length = UnwrapFixedChar(t)) { - SetWithThen(&::substrait::Type::set_allocated_fixed_char)->set_length(*length); + SetWithThen(&substrait::Type::set_allocated_fixed_char)->set_length(*length); return Status::OK(); } if (auto length = UnwrapVarChar(t)) { - SetWithThen(&::substrait::Type::set_allocated_varchar)->set_length(*length); + SetWithThen(&substrait::Type::set_allocated_varchar)->set_length(*length); return Status::OK(); } if (UnwrapIntervalYear(t)) { - return SetWith(&::substrait::Type::set_allocated_interval_year); + return SetWith(&substrait::Type::set_allocated_interval_year); } if (UnwrapIntervalDay(t)) { - return SetWith(&::substrait::Type::set_allocated_interval_day); + return SetWith(&substrait::Type::set_allocated_interval_day); } return NotImplemented(t); @@ -384,10 +379,10 @@ struct DataTypeToProtoImpl { Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } template - Sub* SetWithThen(void (::substrait::Type::*set_allocated_sub)(Sub*)) { + Sub* SetWithThen(void (substrait::Type::*set_allocated_sub)(Sub*)) { auto sub = std::make_unique(); - sub->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE - : ::substrait::Type::NULLABILITY_REQUIRED); + sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); auto out = sub.get(); (type_->*set_allocated_sub)(sub.release()); @@ -395,44 +390,44 @@ struct DataTypeToProtoImpl { } template - Status SetWith(void (::substrait::Type::*set_allocated_sub)(Sub*)) { + Status SetWith(void (substrait::Type::*set_allocated_sub)(Sub*)) { return SetWithThen(set_allocated_sub), Status::OK(); } template Status EncodeUserDefined(const T& t) { ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(t)); - auto user_defined = std::make_unique<::substrait::Type::UserDefined>(); + auto user_defined = std::make_unique(); user_defined->set_type_reference(anchor); - user_defined->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE - : ::substrait::Type::NULLABILITY_REQUIRED); + user_defined->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); type_->set_allocated_user_defined(user_defined.release()); return Status::OK(); } Status NotImplemented(const DataType& t) { - return Status::NotImplemented("conversion to ::substrait::Type from ", t.ToString()); + return Status::NotImplemented("conversion to substrait::Type from ", t.ToString()); } Status operator()(const DataType& type) { return VisitTypeInline(type, this); } - ::substrait::Type* type_; + substrait::Type* type_; bool nullable_; ExtensionSet* ext_set_; const ConversionOptions& conversion_options_; }; } // namespace -Result> ToProto( +Result> ToProto( const DataType& type, bool nullable, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto out = std::make_unique<::substrait::Type>(); + auto out = std::make_unique(); RETURN_NOT_OK( (DataTypeToProtoImpl{out.get(), nullable, ext_set, conversion_options})(type)); return std::move(out); } -Result> FromProto(const ::substrait::NamedStruct& named_struct, +Result> FromProto(const substrait::NamedStruct& named_struct, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (!named_struct.has_struct_()) { @@ -476,28 +471,28 @@ void ToProtoGetDepthFirstNames(const FieldVector& fields, } } // namespace -Result> ToProto( +Result> ToProto( const Schema& schema, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP && schema.metadata() != nullptr) { - return Status::Invalid("::substrait::NamedStruct does not support schema metadata"); + return Status::Invalid("substrait::NamedStruct does not support schema metadata"); } - auto named_struct = std::make_unique<::substrait::NamedStruct>(); + auto named_struct = std::make_unique(); auto names = named_struct->mutable_names(); names->Reserve(schema.num_fields()); ToProtoGetDepthFirstNames(schema.fields(), names); - auto struct_ = std::make_unique<::substrait::Type::Struct>(); + auto struct_ = std::make_unique(); auto types = struct_->mutable_types(); types->Reserve(schema.num_fields()); for (const auto& field : schema.fields()) { if (conversion_options.strictness == ConversionStrictness::EXACT_ROUNDTRIP && field->metadata() != nullptr) { - return Status::Invalid("::substrait::NamedStruct does not support field metadata"); + return Status::Invalid("substrait::NamedStruct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set,