diff --git a/.travis.yml b/.travis.yml index 11985571d30..71cfe17844c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -99,6 +99,7 @@ jobs: -e ARROW_GCS=OFF -e ARROW_MIMALLOC=OFF -e ARROW_ORC=OFF + -e ARROW_ENGINE=OFF -e ARROW_PARQUET=OFF -e ARROW_S3=OFF -e CMAKE_UNITY_BUILD=ON diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index fd7027c30eb..dfef91aeb8d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -351,7 +351,9 @@ if(ARROW_CUDA endif() if(ARROW_ENGINE) + set(ARROW_PARQUET ON) set(ARROW_COMPUTE ON) + set(ARROW_DATASET ON) endif() if(ARROW_SKYHOOK) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 0a43ec18f60..30b1d0e075b 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -225,7 +225,7 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF) - define_option(ARROW_ENGINE "Build the Arrow Execution Engine" OFF) + define_option(ARROW_ENGINE "Build the Arrow Query Engine Module" OFF) define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF) @@ -478,6 +478,16 @@ advised that if this is enabled 'install' will fail silently on components;\ that have not been built" OFF) + set(ARROW_SUBSTRAIT_REPO_DEFAULT "https://github.com/substrait-io/substrait") + define_option_string(ARROW_SUBSTRAIT_REPO + "Custom git repository URL for downloading Substrait sources.;\ +See also ARROW_SUBSTRAIT_TAG" "${ARROW_SUBSTRAIT_REPO_DEFAULT}") + + set(ARROW_SUBSTRAIT_TAG_DEFAULT "e1b4c04a1b518912f4c4065b16a1b2c0ac8e14cf") + define_option_string(ARROW_SUBSTRAIT_TAG + "Custom git hash/tag/branch for Substrait repository.;\ +See also ARROW_SUBSTRAIT_REPO" "${ARROW_SUBSTRAIT_TAG_DEFAULT}") + option(ARROW_BUILD_CONFIG_SUMMARY_JSON "Summarize build configuration in a JSON file" ON) endif() diff --git a/cpp/cmake_modules/FindArrowEngine.cmake b/cpp/cmake_modules/FindArrowEngine.cmake new file mode 100644 index 00000000000..3ee09e0de3d --- /dev/null +++ b/cpp/cmake_modules/FindArrowEngine.cmake @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# - Find Arrow Engine (arrow/engine/api.h, libarrow_engine.a, libarrow_engine.so) +# +# This module requires Arrow from which it uses +# arrow_find_package() +# +# This module defines +# ARROW_ENGINE_FOUND, whether Arrow Engine has been found +# ARROW_ENGINE_IMPORT_LIB, +# path to libarrow_engine's import library (Windows only) +# ARROW_ENGINE_INCLUDE_DIR, directory containing headers +# ARROW_ENGINE_LIB_DIR, directory containing Arrow Engine libraries +# ARROW_ENGINE_SHARED_LIB, path to libarrow_engine's shared library +# ARROW_ENGINE_STATIC_LIB, path to libarrow_engine.a + +if(DEFINED ARROW_ENGINE_FOUND) + return() +endif() + +set(find_package_arguments) +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION) + list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}") +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED) + list(APPEND find_package_arguments REQUIRED) +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY) + list(APPEND find_package_arguments QUIET) +endif() +find_package(Arrow ${find_package_arguments}) +find_package(Parquet ${find_package_arguments}) + +if(ARROW_FOUND AND PARQUET_FOUND) + arrow_find_package(ARROW_ENGINE + "${ARROW_HOME}" + arrow_engine + arrow/engine/api.h + ArrowEngine + arrow-engine) + if(NOT ARROW_ENGINE_VERSION) + set(ARROW_ENGINE_VERSION "${ARROW_VERSION}") + endif() +endif() + +if("${ARROW_ENGINE_VERSION}" VERSION_EQUAL "${ARROW_VERSION}") + set(ARROW_ENGINE_VERSION_MATCH TRUE) +else() + set(ARROW_ENGINE_VERSION_MATCH FALSE) +endif() + +mark_as_advanced(ARROW_ENGINE_IMPORT_LIB + ARROW_ENGINE_INCLUDE_DIR + ARROW_ENGINE_LIBS + ARROW_ENGINE_LIB_DIR + ARROW_ENGINE_SHARED_IMP_LIB + ARROW_ENGINE_SHARED_LIB + ARROW_ENGINE_STATIC_LIB + ARROW_ENGINE_VERSION + ARROW_ENGINE_VERSION_MATCH) + +find_package_handle_standard_args( + ArrowEngine + REQUIRED_VARS ARROW_ENGINE_INCLUDE_DIR ARROW_ENGINE_LIB_DIR ARROW_ENGINE_VERSION_MATCH + VERSION_VAR ARROW_ENGINE_VERSION) +set(ARROW_ENGINE_FOUND ${ArrowEngine_FOUND}) + +if(ArrowEngine_FOUND AND NOT ArrowEngine_FIND_QUIETLY) + message(STATUS "Found the Arrow Engine by ${ARROW_ENGINE_FIND_APPROACH}") + message(STATUS "Found the Arrow Engine shared library: ${ARROW_ENGINE_SHARED_LIB}") + message(STATUS "Found the Arrow Engine import library: ${ARROW_ENGINE_IMPORT_LIB}") + message(STATUS "Found the Arrow Engine static library: ${ARROW_ENGINE_STATIC_LIB}") +endif() diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index d85e511c1e7..54cd6a7a815 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -309,7 +309,8 @@ endif() if(ARROW_ORC OR ARROW_FLIGHT - OR ARROW_GANDIVA) + OR ARROW_GANDIVA + OR ARROW_ENGINE) set(ARROW_WITH_PROTOBUF ON) endif() @@ -1425,6 +1426,11 @@ macro(build_protobuf) set(PROTOBUF_VENDORED TRUE) set(PROTOBUF_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/protobuf_ep-install") set(PROTOBUF_INCLUDE_DIR "${PROTOBUF_PREFIX}/include") + # This flag is based on what the user initially requested but if + # we've fallen back to building protobuf we always build it statically + # so we need to reset the flag so that we can link against it correctly + # later. + set(Protobuf_USE_STATIC_LIBS ON) # Newer protobuf releases always have a lib prefix independent from CMAKE_STATIC_LIBRARY_PREFIX set(PROTOBUF_STATIC_LIB "${PROTOBUF_PREFIX}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}") @@ -1531,7 +1537,7 @@ if(ARROW_WITH_PROTOBUF) PC_PACKAGE_NAMES protobuf) - if(ARROW_PROTOBUF_USE_SHARED AND MSVC_TOOLCHAIN) + if(NOT Protobuf_USE_STATIC_LIBS AND MSVC_TOOLCHAIN) add_definitions(-DPROTOBUF_USE_DLLS) endif() diff --git a/cpp/examples/arrow/CMakeLists.txt b/cpp/examples/arrow/CMakeLists.txt index e46cc7a6fe5..838d7b982c9 100644 --- a/cpp/examples/arrow/CMakeLists.txt +++ b/cpp/examples/arrow/CMakeLists.txt @@ -21,6 +21,10 @@ if(ARROW_COMPUTE) add_arrow_example(compute_register_example) endif() +if(ARROW_ENGINE) + add_arrow_example(engine_substrait_consumption EXTRA_LINK_LIBS arrow_engine_shared) +endif() + if(ARROW_COMPUTE AND ARROW_CSV) add_arrow_example(compute_and_write_csv_example) endif() diff --git a/cpp/examples/arrow/engine_substrait_consumption.cc b/cpp/examples/arrow/engine_substrait_consumption.cc new file mode 100644 index 00000000000..b0109b36888 --- /dev/null +++ b/cpp/examples/arrow/engine_substrait_consumption.cc @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include +#include +#include +#include + +namespace eng = arrow::engine; +namespace cp = arrow::compute; + +#define ABORT_ON_FAILURE(expr) \ + do { \ + arrow::Status status_ = (expr); \ + if (!status_.ok()) { \ + std::cerr << status_.message() << std::endl; \ + abort(); \ + } \ + } while (0); + +class IgnoringConsumer : public cp::SinkNodeConsumer { + public: + explicit IgnoringConsumer(size_t tag) : tag_{tag} {} + + arrow::Status Consume(cp::ExecBatch batch) override { + // Consume a batch of data + // (just print its row count to stdout) + std::cout << "-" << tag_ << " consumed " << batch.length << " rows" << std::endl; + return arrow::Status::OK(); + } + + arrow::Future<> Finish() override { + // Signal to the consumer that the last batch has been delivered + // (we don't do any real work in this consumer so mark it finished immediately) + // + // The returned future should only finish when all outstanding tasks have completed + // (after this method is called Consume is guaranteed not to be called again) + std::cout << "-" << tag_ << " finished" << std::endl; + return arrow::Future<>::MakeFinished(); + } + + private: + // A unique label for instances to help distinguish logging output if a plan has + // multiple sinks + // + // In this example, this is set to the zero-based index of the relation tree in the plan + size_t tag_; +}; + +arrow::Future> GetSubstraitFromServer( + const std::string& filename) { + // Emulate server interaction by parsing hard coded JSON + std::string substrait_json = R"({ + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ {"i64": {}}, {"bool": {}} ] + }, + "names": ["i", "b"] + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER", + "format": "FILE_FORMAT_PARQUET" + } + ] + } + } + }} + ], + "extension_uris": [ + { + "extension_uri_anchor": 7, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + } + ], + "extensions": [ + {"extension_type": { + "extension_uri_reference": 7, + "type_anchor": 42, + "name": "null" + }}, + {"extension_type_variation": { + "extension_uri_reference": 7, + "type_variation_anchor": 23, + "name": "u8" + }}, + {"extension_function": { + "extension_uri_reference": 7, + "function_anchor": 42, + "name": "add" + }} + ] + })"; + std::string filename_placeholder = "FILENAME_PLACEHOLDER"; + substrait_json.replace(substrait_json.find(filename_placeholder), + filename_placeholder.size(), filename); + return eng::internal::SubstraitFromJSON("Plan", substrait_json); +} + +int main(int argc, char** argv) { + if (argc < 2) { + std::cout << "Please specify a parquet file to scan" << std::endl; + // Fake pass for CI + return EXIT_SUCCESS; + } + + // Plans arrive at the consumer serialized in a Buffer, using the binary protobuf + // serialization of a substrait Plan + auto maybe_serialized_plan = GetSubstraitFromServer(argv[1]).result(); + ABORT_ON_FAILURE(maybe_serialized_plan.status()); + std::shared_ptr serialized_plan = + std::move(maybe_serialized_plan).ValueOrDie(); + + // Print the received plan to stdout as JSON + arrow::Result maybe_plan_json = + eng::internal::SubstraitToJSON("Plan", *serialized_plan); + ABORT_ON_FAILURE(maybe_plan_json.status()); + std::cout << std::string(50, '#') << " received substrait::Plan:" << std::endl; + std::cout << maybe_plan_json.ValueOrDie() << std::endl; + + // The data sink(s) for plans is/are implicit in substrait plans, but explicit in + // Arrow. Therefore, deserializing a plan requires a factory for consumers: each + // time the root of a substrait relation tree is deserialized, an Arrow consumer is + // constructed into which its batches will be piped. + std::vector> consumers; + std::function()> consumer_factory = [&] { + // All batches produced by the plan will be fed into IgnoringConsumers: + auto tag = consumers.size(); + consumers.emplace_back(new IgnoringConsumer{tag}); + return consumers.back(); + }; + + // Deserialize each relation tree in the substrait plan to an Arrow compute Declaration + arrow::Result> maybe_decls = + eng::DeserializePlan(*serialized_plan, consumer_factory); + ABORT_ON_FAILURE(maybe_decls.status()); + std::vector decls = std::move(maybe_decls).ValueOrDie(); + + // It's safe to drop the serialized plan; we don't leave references to its memory + serialized_plan.reset(); + + // Construct an empty plan (note: configure Function registry and ThreadPool here) + arrow::Result> maybe_plan = cp::ExecPlan::Make(); + ABORT_ON_FAILURE(maybe_plan.status()); + std::shared_ptr plan = std::move(maybe_plan).ValueOrDie(); + + // Add decls to plan (note: configure ExecNode registry before this point) + for (const cp::Declaration& decl : decls) { + ABORT_ON_FAILURE(decl.AddToPlan(plan.get()).status()); + } + + // Validate the plan and print it to stdout + ABORT_ON_FAILURE(plan->Validate()); + std::cout << std::string(50, '#') << " produced arrow::ExecPlan:" << std::endl; + std::cout << plan->ToString() << std::endl; + + // Start the plan... + std::cout << std::string(50, '#') << " consuming batches:" << std::endl; + ABORT_ON_FAILURE(plan->StartProducing()); + + // ... and wait for it to finish + ABORT_ON_FAILURE(plan->finished().status()); + return EXIT_SUCCESS; +} diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index b984bc10425..e4421345189 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -722,6 +722,10 @@ if(ARROW_COMPUTE) add_subdirectory(compute) endif() +if(ARROW_ENGINE) + add_subdirectory(engine) +endif() + if(ARROW_CUDA) add_subdirectory(gpu) endif() diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index c31a7b72483..11b6b1630c5 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -282,6 +282,8 @@ std::string Array::ToString() const { return ss.str(); } +void PrintTo(const Array& x, std::ostream* os) { *os << x.ToString(); } + Result> Array::View( const std::shared_ptr& out_type) const { ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, diff --git a/cpp/src/arrow/array/array_base.h b/cpp/src/arrow/array/array_base.h index b6b769cf033..c17daad48fa 100644 --- a/cpp/src/arrow/array/array_base.h +++ b/cpp/src/arrow/array/array_base.h @@ -187,10 +187,11 @@ class ARROW_EXPORT Array { Status ValidateFull() const; protected: - Array() : null_bitmap_data_(NULLPTR) {} + Array() = default; + ARROW_DEFAULT_MOVE_AND_ASSIGN(Array); std::shared_ptr data_; - const uint8_t* null_bitmap_data_; + const uint8_t* null_bitmap_data_ = NULLPTR; /// Protected method for constructors void SetData(const std::shared_ptr& data) { @@ -204,6 +205,8 @@ class ARROW_EXPORT Array { private: ARROW_DISALLOW_COPY_AND_ASSIGN(Array); + + ARROW_EXPORT friend void PrintTo(const Array& x, std::ostream* os); }; static inline std::ostream& operator<<(std::ostream& os, const Array& x) { diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index a513bf0f4ab..4b0e0147cea 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -28,6 +28,7 @@ #include "arrow/array/array_primitive.h" #include "arrow/buffer.h" #include "arrow/buffer_builder.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" @@ -286,6 +287,13 @@ ARROW_EXPORT Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out); +inline Result> MakeBuilder( + const std::shared_ptr& type, MemoryPool* pool = default_memory_pool()) { + std::unique_ptr out; + ARROW_RETURN_NOT_OK(MakeBuilder(pool, type, &out)); + return std::move(out); +} + /// \brief Construct an empty ArrayBuilder corresponding to the data /// type, where any top-level or nested dictionary builders return the /// exact index type specified by the type. @@ -293,6 +301,13 @@ ARROW_EXPORT Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out); +inline Result> MakeBuilderExactIndex( + const std::shared_ptr& type, MemoryPool* pool = default_memory_pool()) { + std::unique_ptr out; + ARROW_RETURN_NOT_OK(MakeBuilderExactIndex(pool, type, &out)); + return std::move(out); +} + /// \brief Construct an empty DictionaryBuilder initialized optionally /// with a pre-existing dictionary /// \param[in] pool the MemoryPool to use for allocations @@ -304,4 +319,12 @@ Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr& const std::shared_ptr& dictionary, std::unique_ptr* out); +inline Result> MakeDictionaryBuilder( + const std::shared_ptr& type, const std::shared_ptr& dictionary, + MemoryPool* pool = default_memory_pool()) { + std::unique_ptr out; + ARROW_RETURN_NOT_OK(MakeDictionaryBuilder(pool, type, dictionary, &out)); + return std::move(out); +} + } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h index dc38924d932..f8c686d2c81 100644 --- a/cpp/src/arrow/compute/exec/expression_internal.h +++ b/cpp/src/arrow/compute/exec/expression_internal.h @@ -29,9 +29,6 @@ #include "arrow/util/logging.h" namespace arrow { - -using internal::checked_cast; - namespace compute { struct KnownFieldValues { @@ -213,7 +210,7 @@ struct Comparison { inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) { if (call.function_name != "cast") return nullptr; - return checked_cast(call.options.get()); + return ::arrow::internal::checked_cast(call.options.get()); } inline bool IsSetLookup(const std::string& function) { @@ -223,7 +220,8 @@ inline bool IsSetLookup(const std::string& function) { inline const compute::MakeStructOptions* GetMakeStructOptions( const Expression::Call& call) { if (call.function_name != "make_struct") return nullptr; - return checked_cast(call.options.get()); + return ::arrow::internal::checked_cast( + call.options.get()); } /// A helper for unboxing an Expression composed of associative function calls. @@ -281,7 +279,8 @@ inline Result> GetFunction( return exec_context->func_registry()->GetFunction(call.function_name); } // XXX this special case is strange; why not make "cast" a ScalarFunction? - const auto& to_type = checked_cast(*call.options).to_type; + const auto& to_type = + ::arrow::internal::checked_cast(*call.options).to_type; return compute::GetCastFunction(to_type); } diff --git a/cpp/src/arrow/csv/column_decoder_test.cc b/cpp/src/arrow/csv/column_decoder_test.cc index c8b96e04696..ebac7a3da2f 100644 --- a/cpp/src/arrow/csv/column_decoder_test.cc +++ b/cpp/src/arrow/csv/column_decoder_test.cc @@ -22,6 +22,7 @@ #include +#include "arrow/array/array_base.h" #include "arrow/csv/column_decoder.h" #include "arrow/csv/options.h" #include "arrow/csv/test_common.h" diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index c99316f764a..7ecfb3ead46 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -766,6 +766,17 @@ Result MakeScanNode(compute::ExecPlan* plan, scan_options->filter.Bind(*dataset->schema())); } + // If no projection schema is specified we will use a default projection. In + // general we should not be able to get here if using the ScannerBuilder but + // it is possible to get here if scan_options is used directly. To be cleaned up + // in ARROW-12311 + if (!scan_options->projected_schema) { + ARROW_ASSIGN_OR_RAISE(auto projection_descr, + ProjectionDescr::Default(*dataset->schema())); + scan_options->projected_schema = std::move(projection_descr.schema); + scan_options->projection = projection_descr.expression; + } + if (!scan_options->projection.IsBound()) { auto fields = dataset->schema()->fields(); for (const auto& aug_field : kAugmentedFields) { diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 514b4247316..bce53decd4a 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -149,9 +149,17 @@ struct ARROW_EXPORT Datum { template ::value, bool IsScalar = std::is_base_of::value, typename = enable_if_t> - Datum(const std::shared_ptr& value) // NOLINT implicit conversion + Datum(std::shared_ptr value) // NOLINT implicit conversion : Datum(std::shared_ptr::type>( - value)) {} + std::move(value))) {} + + // Cast from subtypes of Array or Scalar to Datum + template ::type, + bool IsArray = std::is_base_of::value, + bool IsScalar = std::is_base_of::value, + typename = enable_if_t> + Datum(T&& value) // NOLINT implicit conversion + : Datum(std::make_shared(std::forward(value))) {} // Convenience constructors explicit Datum(bool value); diff --git a/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in new file mode 100644 index 00000000000..8fafcda3864 --- /dev/null +++ b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This config sets the following variables in your project:: +# +# ArrowEngine_FOUND - true if Arrow Engine found on the system +# +# This config sets the following targets in your project:: +# +# arrow_engine_shared - for linked as shared library if shared library is built +# arrow_engine_static - for linked as static library if static library is built + +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) +find_dependency(Arrow) +find_dependency(ArrowDataset) +find_dependency(Parquet) + +# Load targets only once. If we load targets multiple times, CMake reports +# already existent target error. +if(NOT (TARGET arrow_engine_shared OR TARGET arrow_engine_static)) + include("${CMAKE_CURRENT_LIST_DIR}/ArrowEngineTargets.cmake") +endif() diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt new file mode 100644 index 00000000000..0f00a6600f4 --- /dev/null +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(arrow_engine) + +arrow_install_all_headers("arrow/engine") + +set(ARROW_ENGINE_LINK_LIBS ${ARROW_PROTOBUF_LIBPROTOBUF}) + +#if(WIN32) +# list(APPEND ARROW_ENGINE_LINK_LIBS ws2_32.lib) +#endif() + +set(ARROW_ENGINE_SRCS + substrait/expression_internal.cc + substrait/extension_set.cc + substrait/extension_types.cc + substrait/serde.cc + substrait/plan_internal.cc + substrait/relation_internal.cc + substrait/type_internal.cc) + +set(SUBSTRAIT_LOCAL_DIR "${CMAKE_CURRENT_BINARY_DIR}/substrait") +set(SUBSTRAIT_GEN_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated") +set(SUBSTRAIT_PROTOS + capabilities + expression + extensions/extensions + function + parameterized_types + plan + relations + type + type_expressions) + +externalproject_add(substrait_ep + GIT_REPOSITORY "${ARROW_SUBSTRAIT_REPO}" + GIT_TAG "${ARROW_SUBSTRAIT_TAG}" + SOURCE_DIR "${SUBSTRAIT_LOCAL_DIR}" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "") + +set(SUBSTRAIT_SUPPRESSED_WARNINGS) +if(MSVC) + # Protobuf generated files trigger some spurious warnings on MSVC. + + # Implicit conversion from uint64_t to uint32_t: + list(APPEND SUBSTRAIT_SUPPRESSED_WARNINGS "/wd4244") + + # Missing dll-interface: + list(APPEND SUBSTRAIT_SUPPRESSED_WARNINGS "/wd4251") +endif() + +set(SUBSTRAIT_PROTO_GEN_ALL) +foreach(SUBSTRAIT_PROTO ${SUBSTRAIT_PROTOS}) + set(SUBSTRAIT_PROTO_GEN "${SUBSTRAIT_GEN_DIR}/substrait/${SUBSTRAIT_PROTO}.pb") + + foreach(EXT h cc) + set_source_files_properties("${SUBSTRAIT_PROTO_GEN}.${EXT}" + PROPERTIES COMPILE_OPTIONS + "${SUBSTRAIT_SUPPRESSED_WARNINGS}" + GENERATED TRUE + SKIP_UNITY_BUILD_INCLUSION TRUE) + add_custom_command(OUTPUT "${SUBSTRAIT_PROTO_GEN}.${EXT}" + COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${SUBSTRAIT_LOCAL_DIR}/proto" + "--cpp_out=${SUBSTRAIT_GEN_DIR}" + "${SUBSTRAIT_LOCAL_DIR}/proto/substrait/${SUBSTRAIT_PROTO}.proto" + DEPENDS ${PROTO_DEPENDS} substrait_ep) + list(APPEND SUBSTRAIT_PROTO_GEN_ALL "${SUBSTRAIT_PROTO_GEN}.${EXT}") + endforeach() + + list(APPEND ARROW_ENGINE_SRCS "${SUBSTRAIT_PROTO_GEN}.cc") +endforeach() + +add_custom_target(substrait_gen ALL DEPENDS ${SUBSTRAIT_PROTO_GEN_ALL}) + +find_package(Git) +add_custom_target(substrait_gen_verify + COMMENT "Verifying that generated substrait accessors are consistent with \ + ARROW_SUBSTRAIT_REPO_AND_TAG='${ARROW_SUBSTRAIT_REPO_AND_TAG}'" + COMMAND ${GIT_EXECUTABLE} diff --exit-code ${SUBSTRAIT_GEN_DIR} + DEPENDS substrait_gen_clear + DEPENDS substrait_gen) + +add_arrow_lib(arrow_engine + CMAKE_PACKAGE_NAME + ArrowEngine + PKG_CONFIG_NAME + arrow-engine + OUTPUTS + ARROW_ENGINE_LIBRARIES + SOURCES + ${ARROW_ENGINE_SRCS} + PRECOMPILED_HEADERS + "$<$:arrow/engine/pch.h>" + SHARED_LINK_FLAGS + ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt + SHARED_LINK_LIBS + arrow_shared + arrow_dataset_shared + ${ARROW_ENGINE_LINK_LIBS} + STATIC_LINK_LIBS + arrow_static + arrow_dataset_static + ${ARROW_ENGINE_LINK_LIBS} + PRIVATE_INCLUDES + ${SUBSTRAIT_GEN_DIR}) + +foreach(LIB_TARGET ${ARROW_ENGINE_LIBRARIES}) + target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_ENGINE_EXPORTING) +endforeach() + +set(ARROW_ENGINE_TEST_LINK_LIBS ${ARROW_ENGINE_LINK_lIBS} ${ARROW_TEST_LINK_LIBS}) +if(ARROW_TEST_LINKAGE STREQUAL "static") + list(APPEND ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_static) +else() + list(APPEND ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_shared) +endif() + +add_arrow_test(substrait_test + SOURCES + substrait/serde_test.cc + EXTRA_LINK_LIBS + ${ARROW_ENGINE_TEST_LINK_LIBS} + PREFIX + "arrow-engine" + LABELS + "arrow_engine") diff --git a/cpp/src/arrow/engine/api.h b/cpp/src/arrow/engine/api.h new file mode 100644 index 00000000000..de996e4d264 --- /dev/null +++ b/cpp/src/arrow/engine/api.h @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/serde.h" diff --git a/cpp/src/arrow/engine/arrow-engine.pc.in b/cpp/src/arrow/engine/arrow-engine.pc.in new file mode 100644 index 00000000000..90fba82a8e9 --- /dev/null +++ b/cpp/src/arrow/engine/arrow-engine.pc.in @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: Apache Arrow Engine +Description: Apache Arrow's Query Engine. +Version: @ARROW_VERSION@ +Requires: arrow +Libs: -L${libdir} -larrow_engine diff --git a/cpp/src/arrow/engine/pch.h b/cpp/src/arrow/engine/pch.h new file mode 100644 index 00000000000..ddb4c120f2a --- /dev/null +++ b/cpp/src/arrow/engine/pch.h @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Often-used headers, for precompiling. +// If updating this header, please make sure you check compilation speed +// before checking in. Adding headers which are not used extremely often +// may incur a slowdown, since it makes the precompiled header heavier to load. + +#include "arrow/pch.h" diff --git a/cpp/src/arrow/engine/simple_extension_type_internal.h b/cpp/src/arrow/engine/simple_extension_type_internal.h new file mode 100644 index 00000000000..b177425a9a9 --- /dev/null +++ b/cpp/src/arrow/engine/simple_extension_type_internal.h @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/extension_type.h" +#include "arrow/util/logging.h" +#include "arrow/util/optional.h" +#include "arrow/util/reflection_internal.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace engine { + +/// \brief A helper class for creating simple extension types +/// +/// Extension types can be parameterized by flat structs +/// +/// Each item in the struct will be serialized and deserialized using +/// the STL insertion and extraction operators (i.e. << and >>). +/// +/// Note: The serialization is a very barebones JSON-like format and +/// probably shouldn't be hand-edited + +template GetStorage(const Params&)> +class SimpleExtensionType : public ExtensionType { + public: + explicit SimpleExtensionType(std::shared_ptr storage_type, Params params = {}) + : ExtensionType(std::move(storage_type)), params_(std::move(params)) {} + + static std::shared_ptr Make(Params params) { + auto storage_type = GetStorage(params); + return std::make_shared(std::move(storage_type), + std::move(params)); + } + + /// \brief Returns the parameters object for the type + /// + /// If the type is not an instance of this extension type then nullptr will be returned + static const Params* GetIf(const DataType& type) { + if (type.id() != Type::EXTENSION) return nullptr; + + const auto& ext_type = ::arrow::internal::checked_cast(type); + if (ext_type.extension_name() != kExtensionName) return nullptr; + + return &::arrow::internal::checked_cast(type).params_; + } + + std::string extension_name() const override { return kExtensionName.to_string(); } + + std::string ToString() const override { return "extension<" + this->Serialize() + ">"; } + + /// \brief A comparator which returns true iff all parameter properties are equal + struct ExtensionEqualsImpl { + ExtensionEqualsImpl(const Params& l, const Params& r) : left_(l), right_(r) { + kProperties->ForEach(*this); + } + + template + void operator()(const Property& prop, size_t i) { + equal_ &= prop.get(left_) == prop.get(right_); + } + + const Params& left_; + const Params& right_; + bool equal_ = true; + }; + + bool ExtensionEquals(const ExtensionType& other) const override { + if (kExtensionName != other.extension_name()) return false; + const auto& other_params = static_cast(other).params_; + return ExtensionEqualsImpl(params_, other_params).equal_; + } + + std::shared_ptr MakeArray(std::shared_ptr data) const override { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ(static_cast(*data->type).extension_name(), + kExtensionName); + return std::make_shared(data); + } + + struct DeserializeImpl { + explicit DeserializeImpl(util::string_view repr) { + Init(kExtensionName, repr, kProperties->size()); + kProperties->ForEach(*this); + } + + void Fail() { params_ = util::nullopt; } + + void Init(util::string_view class_name, util::string_view repr, + size_t num_properties) { + if (!repr.starts_with(class_name)) return Fail(); + + repr = repr.substr(class_name.size()); + if (repr.empty()) return Fail(); + if (repr.front() != '{') return Fail(); + if (repr.back() != '}') return Fail(); + + repr = repr.substr(1, repr.size() - 2); + members_ = ::arrow::internal::SplitString(repr, ','); + if (members_.size() != num_properties) return Fail(); + } + + template + void operator()(const Property& prop, size_t i) { + if (!params_) return; + + auto first_colon = members_[i].find_first_of(':'); + if (first_colon == util::string_view::npos) return Fail(); + + auto name = members_[i].substr(0, first_colon); + if (name != prop.name()) return Fail(); + + auto value_repr = members_[i].substr(first_colon + 1); + typename Property::Type value; + try { + std::stringstream ss(value_repr.to_string()); + ss >> value; + if (!ss.eof()) return Fail(); + } catch (...) { + return Fail(); + } + prop.set(&*params_, std::move(value)); + } + + util::optional params_; + std::vector members_; + }; + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override { + if (auto params = DeserializeImpl(serialized).params_) { + if (!storage_type->Equals(GetStorage(*params))) { + return Status::Invalid("Invalid storage type for ", kExtensionName, ": ", + storage_type->ToString(), " (expected ", + GetStorage(*params)->ToString(), ")"); + } + + return std::make_shared(std::move(storage_type), + std::move(*params)); + } + + return Status::Invalid("Could not parse parameters for extension type ", + extension_name(), " from ", serialized); + } + + struct SerializeImpl { + explicit SerializeImpl(const Params& params) + : params_(params), members_(kProperties->size()) { + kProperties->ForEach(*this); + } + + template + void operator()(const Property& prop, size_t i) { + std::stringstream ss; + ss << prop.name() << ":" << prop.get(params_); + members_[i] = ss.str(); + } + + std::string Finish() { + return kExtensionName.to_string() + "{" + + ::arrow::internal::JoinStrings(members_, ",") + "}"; + } + + const Params& params_; + std::vector members_; + }; + std::string Serialize() const override { return SerializeImpl(params_).Finish(); } + + private: + Params params_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc new file mode 100644 index 00000000000..686ef5d5572 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -0,0 +1,896 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#include "arrow/engine/substrait/expression_internal.h" + +#include + +#include "arrow/builder.h" +#include "arrow/compute/exec/expression.h" +#include "arrow/compute/exec/expression_internal.h" +#include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/make_unique.h" +#include "arrow/visit_scalar_inline.h" + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +namespace internal { +using ::arrow::internal::make_unique; +} // namespace internal + +Result FromProto(const substrait::Expression& expr, + const ExtensionSet& ext_set) { + switch (expr.rex_type_case()) { + case substrait::Expression::kLiteral: { + ARROW_ASSIGN_OR_RAISE(auto datum, FromProto(expr.literal(), ext_set)); + return compute::literal(std::move(datum)); + } + + case substrait::Expression::kSelection: { + if (!expr.selection().has_direct_reference()) break; + + util::optional out; + if (expr.selection().has_expression()) { + ARROW_ASSIGN_OR_RAISE(out, FromProto(expr.selection().expression(), ext_set)); + } + + const auto* ref = &expr.selection().direct_reference(); + while (ref != nullptr) { + switch (ref->reference_type_case()) { + case substrait::Expression::ReferenceSegment::kStructField: { + auto index = ref->struct_field().field(); + if (!out) { + // Root StructField (column selection) + out = compute::field_ref(FieldRef(index)); + } else if (auto out_ref = out->field_ref()) { + // Nested StructFields on the root (selection of struct-typed column + // combined with selecting struct fields) + out = compute::field_ref(FieldRef(*out_ref, index)); + } else if (out->call() && out->call()->function_name == "struct_field") { + // Nested StructFields on top of an arbitrary expression + std::static_pointer_cast( + out->call()->options) + ->indices.push_back(index); + } else { + // First StructField on top of an arbitrary expression + out = compute::call("struct_field", {std::move(*out)}, + arrow::compute::StructFieldOptions({index})); + } + + // Segment handled, continue with child segment (if any) + if (ref->struct_field().has_child()) { + ref = &ref->struct_field().child(); + } else { + ref = nullptr; + } + break; + } + case substrait::Expression::ReferenceSegment::kListElement: { + if (!out) { + // Root ListField (illegal) + return Status::Invalid( + "substrait::ListElement cannot take a Relation as an argument"); + } + + // ListField on top of an arbitrary expression + out = compute::call( + "list_element", + {std::move(*out), compute::literal(ref->list_element().offset())}); + + // Segment handled, continue with child segment (if any) + if (ref->list_element().has_child()) { + ref = &ref->list_element().child(); + } else { + ref = nullptr; + } + break; + } + default: + // Unimplemented construct, break out of loop + out.reset(); + ref = nullptr; + } + } + if (out) { + return *std::move(out); + } + break; + } + + case substrait::Expression::kIfThen: { + const auto& if_then = expr.if_then(); + if (!if_then.has_else_()) break; + if (if_then.ifs_size() == 0) break; + + if (if_then.ifs_size() == 1) { + ARROW_ASSIGN_OR_RAISE(auto if_, FromProto(if_then.ifs(0).if_(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto then, FromProto(if_then.ifs(0).then(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto else_, FromProto(if_then.else_(), ext_set)); + return compute::call("if_else", + {std::move(if_), std::move(then), std::move(else_)}); + } + + std::vector conditions, args; + std::vector condition_names; + conditions.reserve(if_then.ifs_size()); + condition_names.reserve(if_then.ifs_size()); + size_t name_counter = 0; + args.reserve(if_then.ifs_size() + 2); + args.emplace_back(); + for (const auto& if_ : if_then.ifs()) { + ARROW_ASSIGN_OR_RAISE(auto compute_if, FromProto(if_.if_(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto compute_then, FromProto(if_.then(), ext_set)); + conditions.emplace_back(std::move(compute_if)); + args.emplace_back(std::move(compute_then)); + condition_names.emplace_back("cond" + std::to_string(++name_counter)); + } + ARROW_ASSIGN_OR_RAISE(auto compute_else, FromProto(if_then.else_(), ext_set)); + args.emplace_back(std::move(compute_else)); + args[0] = compute::call("make_struct", std::move(conditions), + compute::MakeStructOptions(condition_names)); + return compute::call("case_when", std::move(args)); + } + + case substrait::Expression::kScalarFunction: { + const auto& scalar_fn = expr.scalar_function(); + + ARROW_ASSIGN_OR_RAISE(auto decoded_function, + ext_set.DecodeFunction(scalar_fn.function_reference())); + + std::vector arguments(scalar_fn.args_size()); + for (int i = 0; i < scalar_fn.args_size(); ++i) { + ARROW_ASSIGN_OR_RAISE(arguments[i], FromProto(scalar_fn.args(i), ext_set)); + } + + return compute::call(decoded_function.name.to_string(), std::move(arguments)); + } + + default: + break; + } + + return Status::NotImplemented( + "conversion to arrow::compute::Expression from Substrait expression ", + expr.DebugString()); +} + +Result FromProto(const substrait::Expression::Literal& lit, + const ExtensionSet& ext_set) { + if (lit.nullable()) { + // FIXME not sure how this field should be interpreted and there's no way to round + // trip it through arrow + return Status::Invalid( + "Nullable Literals - Literal.nullable must be left at the default"); + } + + switch (lit.literal_type_case()) { + case substrait::Expression::Literal::kBoolean: + return Datum(lit.boolean()); + + case substrait::Expression::Literal::kI8: + return Datum(static_cast(lit.i8())); + case substrait::Expression::Literal::kI16: + return Datum(static_cast(lit.i16())); + case substrait::Expression::Literal::kI32: + return Datum(static_cast(lit.i32())); + case substrait::Expression::Literal::kI64: + return Datum(static_cast(lit.i64())); + + case substrait::Expression::Literal::kFp32: + return Datum(lit.fp32()); + case substrait::Expression::Literal::kFp64: + return Datum(lit.fp64()); + + case substrait::Expression::Literal::kString: + return Datum(lit.string()); + case substrait::Expression::Literal::kBinary: + return Datum(BinaryScalar(lit.binary())); + + case substrait::Expression::Literal::kTimestamp: + return Datum( + TimestampScalar(static_cast(lit.timestamp()), TimeUnit::MICRO)); + + case substrait::Expression::Literal::kTimestampTz: + return Datum(TimestampScalar(static_cast(lit.timestamp_tz()), + TimeUnit::MICRO, TimestampTzTimezoneString())); + + case substrait::Expression::Literal::kDate: + return Datum(Date32Scalar(lit.date())); + case substrait::Expression::Literal::kTime: + return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO)); + + case substrait::Expression::Literal::kIntervalYearToMonth: + case substrait::Expression::Literal::kIntervalDayToSecond: { + Int32Builder builder; + std::shared_ptr type; + if (lit.has_interval_year_to_month()) { + RETURN_NOT_OK(builder.Append(lit.interval_year_to_month().years())); + RETURN_NOT_OK(builder.Append(lit.interval_year_to_month().months())); + type = interval_year(); + } else { + RETURN_NOT_OK(builder.Append(lit.interval_day_to_second().days())); + RETURN_NOT_OK(builder.Append(lit.interval_day_to_second().seconds())); + type = interval_day(); + } + ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish()); + return Datum( + ExtensionScalar(FixedSizeListScalar(std::move(array)), std::move(type))); + } + + case substrait::Expression::Literal::kUuid: + return Datum(ExtensionScalar(FixedSizeBinaryScalar(lit.uuid()), uuid())); + + case substrait::Expression::Literal::kFixedChar: + return Datum( + ExtensionScalar(FixedSizeBinaryScalar(lit.fixed_char()), + fixed_char(static_cast(lit.fixed_char().size())))); + + case substrait::Expression::Literal::kVarChar: + return Datum( + ExtensionScalar(StringScalar(lit.var_char().value()), + varchar(static_cast(lit.var_char().length())))); + + case substrait::Expression::Literal::kFixedBinary: + return Datum(FixedSizeBinaryScalar(lit.fixed_binary())); + + case substrait::Expression::Literal::kDecimal: { + if (lit.decimal().value().size() != sizeof(Decimal128)) { + return Status::Invalid("Decimal literal had ", lit.decimal().value().size(), + " bytes (expected ", sizeof(Decimal128), ")"); + } + + Decimal128 value; + std::memcpy(value.mutable_native_endian_bytes(), lit.decimal().value().data(), + sizeof(Decimal128)); +#if !ARROW_LITTLE_ENDIAN + std::reverse(value.mutable_native_endian_bytes(), + value.mutable_native_endian_bytes() + sizeof(Decimal128)); +#endif + auto type = decimal128(lit.decimal().precision(), lit.decimal().scale()); + return Datum(Decimal128Scalar(value, std::move(type))); + } + + case substrait::Expression::Literal::kStruct: { + const auto& struct_ = lit.struct_(); + + ScalarVector fields(struct_.fields_size()); + for (int i = 0; i < struct_.fields_size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto field, FromProto(struct_.fields(i), ext_set)); + DCHECK(field.is_scalar()); + fields[i] = field.scalar(); + } + + // Note that Substrait struct types don't have field names, but Arrow does, so we + // just use empty strings for them. + std::vector field_names(fields.size(), ""); + + ARROW_ASSIGN_OR_RAISE( + auto scalar, StructScalar::Make(std::move(fields), std::move(field_names))); + return Datum(std::move(scalar)); + } + + case substrait::Expression::Literal::kList: { + const auto& list = lit.list(); + if (list.values_size() == 0) { + return Status::Invalid( + "substrait::Expression::Literal::List had no values; should have been an " + "substrait::Expression::Literal::EmptyList"); + } + + std::shared_ptr element_type; + + ScalarVector values(list.values_size()); + for (int i = 0; i < list.values_size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto value, FromProto(list.values(i), ext_set)); + DCHECK(value.is_scalar()); + values[i] = value.scalar(); + if (element_type) { + if (!value.type()->Equals(*element_type)) { + return Status::Invalid( + list.DebugString(), + " has a value whose type doesn't match the other list values"); + } + } else { + element_type = value.type(); + } + } + + ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(element_type)); + RETURN_NOT_OK(builder->AppendScalars(values)); + ARROW_ASSIGN_OR_RAISE(auto arr, builder->Finish()); + return Datum(ListScalar(std::move(arr))); + } + + case substrait::Expression::Literal::kMap: { + const auto& map = lit.map(); + if (map.key_values_size() == 0) { + return Status::Invalid( + "substrait::Expression::Literal::Map had no values; should have been an " + "substrait::Expression::Literal::EmptyMap"); + } + + std::shared_ptr key_type, value_type; + ScalarVector keys(map.key_values_size()), values(map.key_values_size()); + for (int i = 0; i < map.key_values_size(); ++i) { + const auto& kv = map.key_values(i); + + static const std::array kMissing = {"key and value", "value", + "key", nullptr}; + if (auto missing = kMissing[kv.has_key() + kv.has_value() * 2]) { + return Status::Invalid("While converting to MapScalar encountered missing ", + missing, " in ", map.DebugString()); + } + ARROW_ASSIGN_OR_RAISE(auto key, FromProto(kv.key(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto value, FromProto(kv.value(), ext_set)); + + DCHECK(key.is_scalar()); + DCHECK(value.is_scalar()); + + keys[i] = key.scalar(); + values[i] = value.scalar(); + + if (key_type) { + if (!key.type()->Equals(*key_type)) { + return Status::Invalid(map.DebugString(), + " has a key whose type doesn't match key_type"); + } + } else { + key_type = value.type(); + } + + if (value_type) { + if (!value.type()->Equals(*value_type)) { + return Status::Invalid(map.DebugString(), + " has a value whose type doesn't match value_type"); + } + } else { + value_type = value.type(); + } + } + + ARROW_ASSIGN_OR_RAISE(auto key_builder, MakeBuilder(key_type)); + ARROW_ASSIGN_OR_RAISE(auto value_builder, MakeBuilder(value_type)); + RETURN_NOT_OK(key_builder->AppendScalars(keys)); + RETURN_NOT_OK(value_builder->AppendScalars(values)); + ARROW_ASSIGN_OR_RAISE(auto key_arr, key_builder->Finish()); + ARROW_ASSIGN_OR_RAISE(auto value_arr, value_builder->Finish()); + ARROW_ASSIGN_OR_RAISE( + auto kv_arr, + StructArray::Make(ArrayVector{std::move(key_arr), std::move(value_arr)}, + std::vector{"key", "value"})); + return Datum(std::make_shared(std::move(kv_arr))); + } + + case substrait::Expression::Literal::kEmptyList: { + ARROW_ASSIGN_OR_RAISE(auto type_nullable, + FromProto(lit.empty_list().type(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto values, MakeEmptyArray(type_nullable.first)); + return ListScalar{std::move(values)}; + } + + case substrait::Expression::Literal::kEmptyMap: { + ARROW_ASSIGN_OR_RAISE(auto key_type_nullable, + FromProto(lit.empty_map().key(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto keys, + MakeEmptyArray(std::move(key_type_nullable.first))); + + ARROW_ASSIGN_OR_RAISE(auto value_type_nullable, + FromProto(lit.empty_map().value(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto values, + MakeEmptyArray(std::move(value_type_nullable.first))); + + auto map_type = std::make_shared(keys->type(), values->type()); + ARROW_ASSIGN_OR_RAISE( + auto key_values, + StructArray::Make( + {std::move(keys), std::move(values)}, + checked_cast(*map_type).value_type()->fields())); + + return MapScalar{std::move(key_values)}; + } + + case substrait::Expression::Literal::kNull: { + ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.null(), ext_set)); + if (!type_nullable.second) { + return Status::Invalid("Substrait null literal ", lit.DebugString(), + " is of non-nullable type"); + } + + return Datum(MakeNullScalar(std::move(type_nullable.first))); + } + + default: + break; + } + + return Status::NotImplemented("conversion to arrow::Datum from Substrait literal ", + lit.DebugString()); +} + +namespace { +struct ScalarToProtoImpl { + Status Visit(const NullScalar& s) { return NotImplemented(s); } + + using Lit = substrait::Expression::Literal; + + template + Status Primitive(void (substrait::Expression::Literal::*set)(Arg), + const PrimitiveScalar& primitive_scalar) { + (lit_->*set)(static_cast(primitive_scalar.value)); + return Status::OK(); + } + + template + Status FromBuffer(void (substrait::Expression::Literal::*set)(std::string&&), + const ScalarWithBufferValue& scalar_with_buffer) { + (lit_->*set)(scalar_with_buffer.value->ToString()); + return Status::OK(); + } + + Status Visit(const BooleanScalar& s) { return Primitive(&Lit::set_boolean, s); } + + Status Visit(const Int8Scalar& s) { return Primitive(&Lit::set_i8, s); } + Status Visit(const Int16Scalar& s) { return Primitive(&Lit::set_i16, s); } + Status Visit(const Int32Scalar& s) { return Primitive(&Lit::set_i32, s); } + Status Visit(const Int64Scalar& s) { return Primitive(&Lit::set_i64, s); } + + Status Visit(const UInt8Scalar& s) { return NotImplemented(s); } + Status Visit(const UInt16Scalar& s) { return NotImplemented(s); } + Status Visit(const UInt32Scalar& s) { return NotImplemented(s); } + Status Visit(const UInt64Scalar& s) { return NotImplemented(s); } + + Status Visit(const HalfFloatScalar& s) { return NotImplemented(s); } + Status Visit(const FloatScalar& s) { return Primitive(&Lit::set_fp32, s); } + Status Visit(const DoubleScalar& s) { return Primitive(&Lit::set_fp64, s); } + + Status Visit(const StringScalar& s) { return FromBuffer(&Lit::set_string, s); } + Status Visit(const BinaryScalar& s) { return FromBuffer(&Lit::set_binary, s); } + + Status Visit(const FixedSizeBinaryScalar& s) { + return FromBuffer(&Lit::set_fixed_binary, s); + } + + Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); } + Status Visit(const Date64Scalar& s) { return NotImplemented(s); } + + Status Visit(const TimestampScalar& s) { + const auto& t = checked_cast(*s.type); + + if (t.unit() != TimeUnit::MICRO) return NotImplemented(s); + + if (t.timezone() == "") return Primitive(&Lit::set_timestamp, s); + + if (t.timezone() == TimestampTzTimezoneString()) { + return Primitive(&Lit::set_timestamp_tz, s); + } + + return NotImplemented(s); + } + + Status Visit(const Time32Scalar& s) { return NotImplemented(s); } + Status Visit(const Time64Scalar& s) { + if (checked_cast(*s.type).unit() != TimeUnit::MICRO) { + return NotImplemented(s); + } + return Primitive(&Lit::set_time, s); + } + + Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); } + Status Visit(const DayTimeIntervalScalar& s) { return NotImplemented(s); } + + Status Visit(const Decimal128Scalar& s) { + auto decimal = internal::make_unique(); + + auto decimal_type = checked_cast(s.type.get()); + decimal->set_precision(decimal_type->precision()); + decimal->set_scale(decimal_type->scale()); + + decimal->set_value(reinterpret_cast(s.value.native_endian_bytes()), + sizeof(Decimal128)); +#if !ARROW_LITTLE_ENDIAN + std::reverse(decimal->mutable_value()->begin(), decimal->mutable_value()->end()); +#endif + lit_->set_allocated_decimal(decimal.release()); + return Status::OK(); + } + + Status Visit(const Decimal256Scalar& s) { return NotImplemented(s); } + + Status Visit(const ListScalar& s) { + if (s.value->length() == 0) { + ARROW_ASSIGN_OR_RAISE(auto list_type, + ToProto(*s.type, /*nullable=*/true, ext_set_)); + lit_->set_allocated_empty_list(list_type->release_list()); + return Status::OK(); + } + + lit_->set_allocated_list(new Lit::List()); + + const auto& list_type = checked_cast(*s.type); + ARROW_ASSIGN_OR_RAISE( + auto element_type, + ToProto(*list_type.value_type(), list_type.value_field()->nullable(), ext_set_)); + + auto values = lit_->mutable_list()->mutable_values(); + values->Reserve(static_cast(s.value->length())); + + for (int64_t i = 0; i < s.value->length(); ++i) { + ARROW_ASSIGN_OR_RAISE(Datum list_element, s.value->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto lit, ToProto(list_element, ext_set_)); + values->AddAllocated(lit.release()); + } + return Status::OK(); + } + + Status Visit(const StructScalar& s) { + lit_->set_allocated_struct_(new Lit::Struct()); + + auto fields = lit_->mutable_struct_()->mutable_fields(); + fields->Reserve(static_cast(s.value.size())); + + for (Datum field : s.value) { + ARROW_ASSIGN_OR_RAISE(auto lit, ToProto(field, ext_set_)); + fields->AddAllocated(lit.release()); + } + return Status::OK(); + } + + Status Visit(const SparseUnionScalar& s) { return NotImplemented(s); } + Status Visit(const DenseUnionScalar& s) { return NotImplemented(s); } + Status Visit(const DictionaryScalar& s) { return NotImplemented(s); } + + Status Visit(const MapScalar& s) { + if (s.value->length() == 0) { + ARROW_ASSIGN_OR_RAISE(auto map_type, ToProto(*s.type, /*nullable=*/true, ext_set_)); + lit_->set_allocated_empty_map(map_type->release_map()); + return Status::OK(); + } + + lit_->set_allocated_map(new Lit::Map()); + + const auto& kv_arr = checked_cast(*s.value); + + auto key_values = lit_->mutable_map()->mutable_key_values(); + key_values->Reserve(static_cast(kv_arr.length())); + + for (int64_t i = 0; i < s.value->length(); ++i) { + auto kv = internal::make_unique(); + + ARROW_ASSIGN_OR_RAISE(Datum key_scalar, kv_arr.field(0)->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto key, ToProto(key_scalar, ext_set_)); + kv->set_allocated_key(key.release()); + + ARROW_ASSIGN_OR_RAISE(Datum value_scalar, kv_arr.field(1)->GetScalar(i)); + ARROW_ASSIGN_OR_RAISE(auto value, ToProto(value_scalar, ext_set_)); + kv->set_allocated_value(value.release()); + + key_values->AddAllocated(kv.release()); + } + return Status::OK(); + } + + Status Visit(const ExtensionScalar& s) { + if (UnwrapUuid(*s.type)) { + return FromBuffer(&Lit::set_uuid, + checked_cast(*s.value)); + } + + if (UnwrapFixedChar(*s.type)) { + return FromBuffer(&Lit::set_fixed_char, + checked_cast(*s.value)); + } + + if (auto length = UnwrapVarChar(*s.type)) { + auto var_char = internal::make_unique(); + var_char->set_length(*length); + var_char->set_value(checked_cast(*s.value).value->ToString()); + + lit_->set_allocated_var_char(var_char.release()); + return Status::OK(); + } + + auto GetPairOfInts = [&] { + const auto& array = *checked_cast(*s.value).value; + auto ints = checked_cast(array).raw_values(); + return std::make_pair(ints[0], ints[1]); + }; + + if (UnwrapIntervalYear(*s.type)) { + auto interval_year = internal::make_unique(); + interval_year->set_years(GetPairOfInts().first); + interval_year->set_months(GetPairOfInts().second); + + lit_->set_allocated_interval_year_to_month(interval_year.release()); + return Status::OK(); + } + + if (UnwrapIntervalDay(*s.type)) { + auto interval_day = internal::make_unique(); + interval_day->set_days(GetPairOfInts().first); + interval_day->set_seconds(GetPairOfInts().second); + + lit_->set_allocated_interval_day_to_second(interval_day.release()); + return Status::OK(); + } + + return NotImplemented(s); + } + + Status Visit(const FixedSizeListScalar& s) { return NotImplemented(s); } + Status Visit(const DurationScalar& s) { return NotImplemented(s); } + Status Visit(const LargeStringScalar& s) { return NotImplemented(s); } + Status Visit(const LargeBinaryScalar& s) { return NotImplemented(s); } + Status Visit(const LargeListScalar& s) { return NotImplemented(s); } + Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); } + + Status NotImplemented(const Scalar& s) { + return Status::NotImplemented("conversion to substrait::Expression::Literal from ", + s.ToString()); + } + + Status operator()(const Scalar& scalar) { return VisitScalarInline(scalar, this); } + + substrait::Expression::Literal* lit_; + ExtensionSet* ext_set_; +}; +} // namespace + +Result> ToProto(const Datum& datum, + ExtensionSet* ext_set) { + if (!datum.is_scalar()) { + return Status::NotImplemented("representing ", datum.ToString(), + " as a substrait::Expression::Literal"); + } + + auto out = internal::make_unique(); + + if (datum.scalar()->is_valid) { + RETURN_NOT_OK((ScalarToProtoImpl{out.get(), ext_set})(*datum.scalar())); + } else { + ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*datum.type(), /*nullable=*/true, ext_set)); + out->set_allocated_null(type.release()); + } + + return std::move(out); +} + +static Status AddChildToReferenceSegment( + substrait::Expression::ReferenceSegment& segment, + std::unique_ptr&& child) { + auto status = Status::Invalid("Attempt to add child to incomplete reference segment"); + switch (segment.reference_type_case()) { + case substrait::Expression::ReferenceSegment::kMapKey: { + auto map_key = segment.mutable_map_key(); + if (map_key->has_child()) { + status = AddChildToReferenceSegment(*map_key->mutable_child(), std::move(child)); + } else { + map_key->set_allocated_child(child.release()); + status = Status::OK(); + } + break; + } + case substrait::Expression::ReferenceSegment::kStructField: { + auto struct_field = segment.mutable_struct_field(); + if (struct_field->has_child()) { + status = + AddChildToReferenceSegment(*struct_field->mutable_child(), std::move(child)); + } else { + struct_field->set_allocated_child(child.release()); + status = Status::OK(); + } + break; + } + case substrait::Expression::ReferenceSegment::kListElement: { + auto list_element = segment.mutable_list_element(); + if (list_element->has_child()) { + status = + AddChildToReferenceSegment(*list_element->mutable_child(), std::move(child)); + } else { + list_element->set_allocated_child(child.release()); + status = Status::OK(); + } + break; + } + default: + break; + } + return status; +} + +// Indexes the given Substrait expression or root (if expr is empty) using the given +// ReferenceSegment. +static Result> MakeDirectReference( + std::unique_ptr&& expr, + std::unique_ptr&& ref_segment) { + // If expr is already a selection expression, add the index to its index stack. + if (expr && expr->has_selection() && expr->selection().has_direct_reference()) { + auto selection = expr->mutable_selection(); + auto root_ref_segment = selection->mutable_direct_reference(); + auto status = AddChildToReferenceSegment(*root_ref_segment, std::move(ref_segment)); + if (status.ok()) { + return std::move(expr); + } + } + + auto selection = internal::make_unique(); + selection->set_allocated_direct_reference(ref_segment.release()); + + if (expr && expr->rex_type_case() != substrait::Expression::REX_TYPE_NOT_SET) { + selection->set_allocated_expression(expr.release()); + } else { + selection->set_allocated_root_reference( + new substrait::Expression::FieldReference::RootReference()); + } + + auto out = internal::make_unique(); + out->set_allocated_selection(selection.release()); + return std::move(out); +} + +// Indexes the given Substrait struct-typed expression or root (if expr is empty) using +// the given field index. +static Result> MakeStructFieldReference( + std::unique_ptr&& expr, int field) { + auto struct_field = + internal::make_unique(); + struct_field->set_field(field); + + auto ref_segment = internal::make_unique(); + ref_segment->set_allocated_struct_field(struct_field.release()); + + return MakeDirectReference(std::move(expr), std::move(ref_segment)); +} + +// Indexes the given Substrait list-typed expression using the given offset. +static Result> MakeListElementReference( + std::unique_ptr&& expr, int offset) { + auto list_element = + internal::make_unique(); + list_element->set_offset(offset); + + auto ref_segment = internal::make_unique(); + ref_segment->set_allocated_list_element(list_element.release()); + + return MakeDirectReference(std::move(expr), std::move(ref_segment)); +} + +Result> ToProto(const compute::Expression& expr, + ExtensionSet* ext_set) { + if (!expr.IsBound()) { + return Status::Invalid("ToProto requires a bound Expression"); + } + + auto out = internal::make_unique(); + + if (auto datum = expr.literal()) { + ARROW_ASSIGN_OR_RAISE(auto literal, ToProto(*datum, ext_set)); + out->set_allocated_literal(literal.release()); + return std::move(out); + } + + if (auto param = expr.parameter()) { + // Special case of a nested StructField + DCHECK(!param->indices.empty()); + + for (int index : param->indices) { + ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index)); + } + + return std::move(out); + } + + auto call = CallNotNull(expr); + + if (call->function_name == "case_when") { + auto conditions = call->arguments[0].call(); + if (conditions && conditions->function_name == "make_struct") { + // catch the special case of calls convertible to IfThen + auto if_then_ = internal::make_unique(); + + // don't try to convert argument 0 of the case_when; we have to convert the elements + // of make_struct individually + std::vector> arguments( + call->arguments.size() - 1); + for (size_t i = 1; i < call->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(arguments[i - 1], ToProto(call->arguments[i], ext_set)); + } + + for (size_t i = 0; i < conditions->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto cond_substrait, + ToProto(conditions->arguments[i], ext_set)); + auto clause = internal::make_unique(); + clause->set_allocated_if_(cond_substrait.release()); + clause->set_allocated_then(arguments[i].release()); + if_then_->mutable_ifs()->AddAllocated(clause.release()); + } + + if_then_->set_allocated_else_(arguments.back().release()); + + out->set_allocated_if_then(if_then_.release()); + return std::move(out); + } + } + + // the remaining function pattern matchers only convert the function itself, so we + // should be able to convert all its arguments first here + std::vector> arguments(call->arguments.size()); + for (size_t i = 0; i < arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(arguments[i], ToProto(call->arguments[i], ext_set)); + } + + if (call->function_name == "struct_field") { + // catch the special case of calls convertible to a StructField + out = std::move(arguments[0]); + for (int index : + checked_cast(*call->options) + .indices) { + ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index)); + } + + return std::move(out); + } + + if (call->function_name == "list_element") { + // catch the special case of calls convertible to a ListElement + if (arguments[0]->has_selection() && + arguments[0]->selection().has_direct_reference()) { + if (arguments[1]->has_literal() && arguments[1]->literal().has_i32()) { + return MakeListElementReference(std::move(arguments[0]), + arguments[1]->literal().i32()); + } + } + } + + if (call->function_name == "if_else") { + // catch the special case of calls convertible to IfThen + auto if_clause = internal::make_unique(); + if_clause->set_allocated_if_(arguments[0].release()); + if_clause->set_allocated_then(arguments[1].release()); + + auto if_then = internal::make_unique(); + if_then->mutable_ifs()->AddAllocated(if_clause.release()); + if_then->set_allocated_else_(arguments[2].release()); + + out->set_allocated_if_then(if_then.release()); + return std::move(out); + } + + // other expression types dive into extensions immediately + ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set->EncodeFunction(call->function_name)); + + auto scalar_fn = internal::make_unique(); + scalar_fn->set_function_reference(anchor); + scalar_fn->mutable_args()->Reserve(static_cast(arguments.size())); + for (auto& arg : arguments) { + scalar_fn->mutable_args()->AddAllocated(arg.release()); + } + + out->set_allocated_scalar_function(scalar_fn.release()); + return std::move(out); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h new file mode 100644 index 00000000000..e491fa674cf --- /dev/null +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" + +#include "substrait/expression.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::Expression&, const ExtensionSet&); + +ARROW_ENGINE_EXPORT +Result> ToProto(const compute::Expression&, + ExtensionSet*); + +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::Expression::Literal&, const ExtensionSet&); + +ARROW_ENGINE_EXPORT +Result> ToProto(const Datum&, + ExtensionSet*); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc new file mode 100644 index 00000000000..fe43ab28799 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -0,0 +1,367 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/extension_set.h" + +#include +#include + +#include "arrow/util/hash_util.h" +#include "arrow/util/hashing.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace engine { +namespace { + +struct TypePtrHashEq { + template + size_t operator()(const Ptr& type) const { + return type->Hash(); + } + + template + bool operator()(const Ptr& l, const Ptr& r) const { + return *l == *r; + } +}; + +struct IdHashEq { + using Id = ExtensionSet::Id; + + size_t operator()(Id id) const { + constexpr ::arrow::internal::StringViewHash hash = {}; + auto out = static_cast(hash(id.uri)); + ::arrow::internal::hash_combine(out, hash(id.name)); + return out; + } + + bool operator()(Id l, Id r) const { return l.uri == r.uri && l.name == r.name; } +}; + +} // namespace + +// A builder used when creating a Substrait plan from an Arrow execution plan. In +// that situation we do not have a set of anchor values already defined so we keep +// a map of what Ids we have seen. +struct ExtensionSet::Impl { + void AddUri(util::string_view uri, ExtensionSet* self) { + if (uris_.find(uri) != uris_.end()) return; + + self->uris_.push_back(uri); + uris_.insert(self->uris_.back()); // lookup helper's keys should reference memory + // owned by this ExtensionSet + } + + Status CheckHasUri(util::string_view uri) { + if (uris_.find(uri) != uris_.end()) return Status::OK(); + + return Status::Invalid( + "Uri ", uri, + " was referenced by an extension but was not declared in the ExtensionSet."); + } + + uint32_t EncodeType(ExtensionIdRegistry::TypeRecord type_record, ExtensionSet* self) { + // note: at this point we're guaranteed to have an Id which points to memory owned by + // the set's registry. + AddUri(type_record.id.uri, self); + auto it_success = + types_.emplace(type_record.id, static_cast(types_.size())); + + if (it_success.second) { + self->types_.push_back( + {type_record.id, type_record.type, type_record.is_variation}); + } + + return it_success.first->second; + } + + uint32_t EncodeFunction(Id id, util::string_view function_name, ExtensionSet* self) { + // note: at this point we're guaranteed to have an Id which points to memory owned by + // the set's registry. + AddUri(id.uri, self); + auto it_success = functions_.emplace(id, static_cast(functions_.size())); + + if (it_success.second) { + self->functions_.push_back({id, function_name}); + } + + return it_success.first->second; + } + + std::unordered_set uris_; + std::unordered_map types_, functions_; +}; + +ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) + : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {} + +Result ExtensionSet::Make(std::vector uris, + std::vector type_ids, + std::vector type_is_variation, + std::vector function_ids, + ExtensionIdRegistry* registry) { + ExtensionSet set; + set.registry_ = registry; + + // TODO(bkietz) move this into the registry as registry->OwnUris(&uris) or so + std::unordered_set + uris_owned_by_registry; + for (util::string_view uri : registry->Uris()) { + uris_owned_by_registry.insert(uri); + } + + for (auto& uri : uris) { + if (uri.empty()) continue; + auto it = uris_owned_by_registry.find(uri); + if (it == uris_owned_by_registry.end()) { + return Status::KeyError("Uri '", uri, "' not found in registry"); + } + uri = *it; // Ensure uris point into the registry's memory + set.impl_->AddUri(*it, &set); + } + + if (type_ids.size() != type_is_variation.size()) { + return Status::Invalid("Received ", type_ids.size(), " type ids but a ", + type_is_variation.size(), "-long is_variation vector"); + } + + set.types_.resize(type_ids.size()); + + for (size_t i = 0; i < type_ids.size(); ++i) { + if (type_ids[i].empty()) continue; + RETURN_NOT_OK(set.impl_->CheckHasUri(type_ids[i].uri)); + + if (auto rec = registry->GetType(type_ids[i], type_is_variation[i])) { + set.types_[i] = {rec->id, rec->type, rec->is_variation}; + continue; + } + return Status::Invalid("Type", (type_is_variation[i] ? " variation" : ""), " ", + type_ids[i].uri, "#", type_ids[i].name, " not found"); + } + + set.functions_.resize(function_ids.size()); + + for (size_t i = 0; i < function_ids.size(); ++i) { + if (function_ids[i].empty()) continue; + RETURN_NOT_OK(set.impl_->CheckHasUri(function_ids[i].uri)); + + if (auto rec = registry->GetFunction(function_ids[i])) { + set.functions_[i] = {rec->id, rec->function_name}; + continue; + } + return Status::Invalid("Function ", function_ids[i].uri, "#", type_ids[i].name, + " not found"); + } + + set.uris_ = std::move(uris); + + return std::move(set); +} + +Result ExtensionSet::DecodeType(uint32_t anchor) const { + if (anchor >= types_.size() || types_[anchor].id.empty()) { + return Status::Invalid("User defined type reference ", anchor, + " did not have a corresponding anchor in the extension set"); + } + return types_[anchor]; +} + +Result ExtensionSet::EncodeType(const DataType& type) { + if (auto rec = registry_->GetType(type)) { + return impl_->EncodeType(*rec, this); + } + return Status::KeyError("type ", type.ToString(), " not found in the registry"); +} + +Result ExtensionSet::DecodeFunction(uint32_t anchor) const { + if (anchor >= functions_.size() || functions_[anchor].id.empty()) { + return Status::Invalid("User defined function reference ", anchor, + " did not have a corresponding anchor in the extension set"); + } + return functions_[anchor]; +} + +Result ExtensionSet::EncodeFunction(util::string_view function_name) { + if (auto rec = registry_->GetFunction(function_name)) { + return impl_->EncodeFunction(rec->id, rec->function_name, this); + } + return Status::KeyError("function ", function_name, " not found in the registry"); +} + +template +const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { + auto it = key_to_index.find(key); + if (it == key_to_index.end()) return nullptr; + return &it->second; +} + +ExtensionIdRegistry* default_extension_id_registry() { + static struct Impl : ExtensionIdRegistry { + Impl() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535. + for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/true)); + } + + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/false)); + } + + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. + for (util::string_view name : { + "add", + }) { + DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + } + } + + std::vector Uris() const override { + return {uris_.begin(), uris_.end()}; + } + + util::optional GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; + } + return {}; + } + + util::optional GetType(Id id, bool is_variation) const override { + if (auto index = + GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; + } + return {}; + } + + Status RegisterType(Id id, std::shared_ptr type, + bool is_variation) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); + + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + auto index = static_cast(type_ids_.size()); + + auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; + auto it_success = id_to_index->emplace(copied_id, index); + + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); + } + + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index->erase(it_success.first); + return Status::Invalid("Type was already registered"); + } + + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + type_is_variation_.push_back(is_variation); + return Status::OK(); + } + + util::optional GetFunction( + util::string_view arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } + + util::optional GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } + + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; + + auto index = static_cast(function_ids_.size()); + + auto it_success = function_id_to_index_.emplace(copied_id, index); + + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); + } + + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); + } + + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } + + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set uris_, names_, function_names_; + DataTypeVector types_; + std::vector type_is_variation_; + + // non-owning lookup helpers + std::vector type_ids_, function_ids_; + std::unordered_map id_to_index_, variation_id_to_index_; + std::unordered_map type_to_index_; + + std::vector function_name_ptrs_; + std::unordered_map function_id_to_index_; + std::unordered_map + function_name_to_index_; + } impl_; + + return &impl_; +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h new file mode 100644 index 00000000000..2eb44822375 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" +#include "arrow/util/optional.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace engine { + +/// Substrait identifies functions and custom data types using a (uri, name) pair. +/// +/// This registry is a bidirectional mapping between Substrait IDs and their corresponding +/// Arrow counterparts (arrow::DataType and function names in a function registry) +/// +/// Substrait extension types and variations must be registered with their corresponding +/// arrow::DataType before they can be used! +/// +/// Conceptually this can be thought of as two pairs of `unordered_map`s. One pair to +/// go back and forth between Substrait ID and arrow::DataType and another pair to go +/// back and forth between Substrait ID and Arrow function names. +/// +/// Unlike an ExtensionSet this registry is not created automatically when consuming +/// Substrait plans and must be configured ahead of time (although there is a default +/// instance). +class ARROW_ENGINE_EXPORT ExtensionIdRegistry { + public: + /// All uris registered in this ExtensionIdRegistry + virtual std::vector Uris() const = 0; + + struct Id { + util::string_view uri, name; + + bool empty() const { return uri.empty() && name.empty(); } + }; + + /// \brief A mapping between a Substrait ID and an arrow::DataType + struct TypeRecord { + Id id; + const std::shared_ptr& type; + bool is_variation; + }; + virtual util::optional GetType(const DataType&) const = 0; + virtual util::optional GetType(Id, bool is_variation) const = 0; + virtual Status RegisterType(Id, std::shared_ptr, bool is_variation) = 0; + + /// \brief A mapping between a Substrait ID and an Arrow function + /// + /// Note: At the moment we identify functions solely by the name + /// of the function in the function registry. + /// + /// TODO(ARROW-15582) some functions will not be simple enough to convert without access + /// to their arguments/options. For example is_in embeds the set in options rather than + /// using an argument: + /// is_in(x, SetLookupOptions(set)) <-> (k...Uri, "is_in")(x, set) + /// + /// ... for another example, depending on the value of the first argument to + /// substrait::add it either corresponds to arrow::add or arrow::add_checked + struct FunctionRecord { + Id id; + const std::string& function_name; + }; + virtual util::optional GetFunction(Id) const = 0; + virtual util::optional GetFunction( + util::string_view arrow_function_name) const = 0; + virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; +}; + +constexpr util::string_view kArrowExtTypesUri = + "https://github.com/apache/arrow/blob/master/format/substrait/" + "extension_types.yaml"; + +/// A default registry with all supported functions and data types registered +/// +/// Note: Function support is currently very minimal, see ARROW-15538 +ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); + +/// \brief A set of extensions used within a plan +/// +/// Each time an extension is used within a Substrait plan the extension +/// must be included in an extension set that is defined at the root of the +/// plan. +/// +/// The plan refers to a specific extension using an "anchor" which is an +/// arbitrary integer invented by the producer that has no meaning beyond a +/// plan but which should be consistent within a plan. +/// +/// To support serialization and deserialization this type serves as a +/// bidirectional map between Substrait ID and "anchor"s. +/// +/// When deserializing a Substrait plan the extension set should be extracted +/// after the plan has been converted from Protobuf and before the plan +/// is converted to an execution plan. +/// +/// The extension set can be kept and reused during serialization if a perfect +/// round trip is required. If serialization is not needed or round tripping +/// is not required then the extension set can be safely discarded after the +/// plan has been converted into an execution plan. +/// +/// When converting an execution plan into a Substrait plan an extension set +/// can be automatically generated or a previously generated extension set can +/// be used. +/// +/// ExtensionSet does not own strings; it only refers to strings in an +/// ExtensionIdRegistry. +class ARROW_ENGINE_EXPORT ExtensionSet { + public: + using Id = ExtensionIdRegistry::Id; + + struct FunctionRecord { + Id id; + util::string_view name; + }; + + struct TypeRecord { + Id id; + std::shared_ptr type; + bool is_variation; + }; + + /// Construct an empty ExtensionSet to be populated during serialization. + explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry()); + ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet); + + /// Construct an ExtensionSet with explicit extension ids for efficient referencing + /// during deserialization. Note that input vectors need not be densely packed; an empty + /// (default constructed) Id may be used as a placeholder to indicate an unused + /// _anchor/_reference. This factory will be used to wrap the extensions declared in a + /// substrait::Plan before deserializing the plan's relations. + /// + /// Views will be replaced with equivalent views pointing to memory owned by the + /// registry. + /// + /// Note: This is an advanced operation. The order of the ids, types, and functions + /// must match the anchor numbers chosen for a plan. + /// + /// An extension set should instead be created using + /// arrow::engine::GetExtensionSetFromPlan + static Result Make( + std::vector uris, std::vector type_ids, + std::vector type_is_variation, std::vector function_ids, + ExtensionIdRegistry* = default_extension_id_registry()); + + // index in these vectors == value of _anchor/_reference fields + /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not + /// guaranteed. Could it be? + const std::vector& uris() const { return uris_; } + + /// \brief Returns a data type given an anchor + /// + /// This is used when converting a Substrait plan to an Arrow execution plan. + /// + /// If the anchor does not exist in this extension set an error will be returned. + Result DecodeType(uint32_t anchor) const; + + /// \brief Returns the number of custom type records in this extension set + /// + /// Note: the types are currently stored as a sparse vector, so this may return a value + /// larger than the actual number of types. This behavior may change in the future; see + /// ARROW-15583. + std::size_t num_types() const { return types_.size(); } + + /// \brief Lookup the anchor for a given type + /// + /// This operation is used when converting an Arrow execution plan to a Substrait plan. + /// If the type has been previously encoded then the same anchor value will returned. + /// + /// If the type has not been previously encoded then a new anchor value will be created. + /// + /// If the type does not exist in the extension id registry then an error will be + /// returned. + /// + /// \return An anchor that can be used to refer to the type within a plan + Result EncodeType(const DataType& type); + + /// \brief Returns a function given an anchor + /// + /// This is used when converting a Substrait plan to an Arrow execution plan. + /// + /// If the anchor does not exist in this extension set an error will be returned. + Result DecodeFunction(uint32_t anchor) const; + + /// \brief Lookup the anchor for a given function + /// + /// This operation is used when converting an Arrow execution plan to a Substrait plan. + /// If the function has been previously encoded then the same anchor value will be + /// returned. + /// + /// If the function has not been previously encoded then a new anchor value will be + /// created. + /// + /// If the function name is not in the extension id registry then an error will be + /// returned. + /// + /// \return An anchor that can be used to refer to the function within a plan + Result EncodeFunction(util::string_view function_name); + + /// \brief Returns the number of custom functions in this extension set + /// + /// Note: the functions are currently stored as a sparse vector, so this may return a + /// value larger than the actual number of functions. This behavior may change in the + /// future; see ARROW-15583. + std::size_t num_functions() const { return functions_.size(); } + + private: + ExtensionIdRegistry* registry_; + /// The subset of extension registry URIs referenced by this extension set + std::vector uris_; + std::vector types_; + + std::vector functions_; + + // pimpl pattern to hide lookup details + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc new file mode 100644 index 00000000000..b8fd191b3fd --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extension_types.cc @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/extension_types.h" + +#include "arrow/engine/simple_extension_type_internal.h" +#include "arrow/util/hashing.h" +#include "arrow/util/string_view.h" + +namespace arrow { + +using internal::DataMember; +using internal::MakeProperties; + +namespace engine { +namespace { + +constexpr util::string_view kUuidExtensionName = "uuid"; +struct UuidExtensionParams {}; +std::shared_ptr UuidGetStorage(const UuidExtensionParams&) { + return fixed_size_binary(16); +} +static auto kUuidExtensionParamsProperties = MakeProperties(); + +using UuidType = SimpleExtensionType; + +constexpr util::string_view kFixedCharExtensionName = "fixed_char"; +struct FixedCharExtensionParams { + int32_t length; +}; +std::shared_ptr FixedCharGetStorage(const FixedCharExtensionParams& params) { + return fixed_size_binary(params.length); +} +static auto kFixedCharExtensionParamsProperties = + MakeProperties(DataMember("length", &FixedCharExtensionParams::length)); + +using FixedCharType = + SimpleExtensionType; + +constexpr util::string_view kVarCharExtensionName = "varchar"; +struct VarCharExtensionParams { + int32_t length; +}; +std::shared_ptr VarCharGetStorage(const VarCharExtensionParams&) { + return utf8(); +} +static auto kVarCharExtensionParamsProperties = + MakeProperties(DataMember("length", &VarCharExtensionParams::length)); + +using VarCharType = + SimpleExtensionType; + +constexpr util::string_view kIntervalYearExtensionName = "interval_year"; +struct IntervalYearExtensionParams {}; +std::shared_ptr IntervalYearGetStorage(const IntervalYearExtensionParams&) { + return fixed_size_list(int32(), 2); +} +static auto kIntervalYearExtensionParamsProperties = MakeProperties(); + +using IntervalYearType = + SimpleExtensionType; + +constexpr util::string_view kIntervalDayExtensionName = "interval_day"; +struct IntervalDayExtensionParams {}; +std::shared_ptr IntervalDayGetStorage(const IntervalDayExtensionParams&) { + return fixed_size_list(int32(), 2); +} +static auto kIntervalDayExtensionParamsProperties = MakeProperties(); + +using IntervalDayType = + SimpleExtensionType; + +} // namespace + +std::shared_ptr uuid() { return UuidType::Make({}); } + +std::shared_ptr fixed_char(int32_t length) { + return FixedCharType::Make({length}); +} + +std::shared_ptr varchar(int32_t length) { return VarCharType::Make({length}); } + +std::shared_ptr interval_year() { return IntervalYearType::Make({}); } + +std::shared_ptr interval_day() { return IntervalDayType::Make({}); } + +bool UnwrapUuid(const DataType& t) { + if (UuidType::GetIf(t)) { + return true; + } + return false; +} + +util::optional UnwrapFixedChar(const DataType& t) { + if (auto params = FixedCharType::GetIf(t)) { + return params->length; + } + return util::nullopt; +} + +util::optional UnwrapVarChar(const DataType& t) { + if (auto params = VarCharType::GetIf(t)) { + return params->length; + } + return util::nullopt; +} + +bool UnwrapIntervalYear(const DataType& t) { + if (IntervalYearType::GetIf(t)) { + return true; + } + return false; +} + +bool UnwrapIntervalDay(const DataType& t) { + if (IntervalDayType::GetIf(t)) { + return true; + } + return false; +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h new file mode 100644 index 00000000000..e689e94722e --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extension_types.h @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include "arrow/buffer.h" +#include "arrow/compute/function.h" +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" +#include "arrow/util/optional.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace engine { + +// arrow::ExtensionTypes are provided to wrap uuid, fixed_char, varchar, interval_year, +// and interval_day which are first-class types in substrait but do not appear in +// the arrow type system. +// +// Note that these are not automatically registered with arrow::RegisterExtensionType(), +// which means among other things that serialization of these types to IPC would fail. + +/// fixed_size_binary(16) for storing Universally Unique IDentifiers +ARROW_ENGINE_EXPORT +std::shared_ptr uuid(); + +/// fixed_size_binary(length) constrained to contain only valid UTF-8 +ARROW_ENGINE_EXPORT +std::shared_ptr fixed_char(int32_t length); + +/// utf8() constrained to be shorter than `length` +ARROW_ENGINE_EXPORT +std::shared_ptr varchar(int32_t length); + +/// fixed_size_list(int32(), 2) storing a number of [years, months] +ARROW_ENGINE_EXPORT +std::shared_ptr interval_year(); + +/// fixed_size_list(int32(), 2) storing a number of [days, seconds] +ARROW_ENGINE_EXPORT +std::shared_ptr interval_day(); + +/// Return true if t is Uuid, otherwise false +ARROW_ENGINE_EXPORT +bool UnwrapUuid(const DataType&); + +/// Return FixedChar length if t is FixedChar, otherwise nullopt +ARROW_ENGINE_EXPORT +util::optional UnwrapFixedChar(const DataType&); + +/// Return Varchar (max) length if t is VarChar, otherwise nullopt +ARROW_ENGINE_EXPORT +util::optional UnwrapVarChar(const DataType& t); + +/// Return true if t is IntervalYear, otherwise false +ARROW_ENGINE_EXPORT +bool UnwrapIntervalYear(const DataType&); + +/// Return true if t is IntervalDay, otherwise false +ARROW_ENGINE_EXPORT +bool UnwrapIntervalDay(const DataType&); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc new file mode 100644 index 00000000000..8ffbcc005da --- /dev/null +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/plan_internal.h" + +#include "arrow/result.h" +#include "arrow/util/hashing.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/unreachable.h" + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +namespace internal { +using ::arrow::internal::make_unique; +} // namespace internal + +Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) { + plan->clear_extension_uris(); + + std::unordered_map map; + + auto uris = plan->mutable_extension_uris(); + uris->Reserve(static_cast(ext_set.uris().size())); + for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) { + auto uri = ext_set.uris()[anchor]; + if (uri.empty()) continue; + + auto ext_uri = internal::make_unique(); + ext_uri->set_uri(uri.to_string()); + ext_uri->set_extension_uri_anchor(anchor); + uris->AddAllocated(ext_uri.release()); + + map[uri] = anchor; + } + + auto extensions = plan->mutable_extensions(); + extensions->Reserve(static_cast(ext_set.num_types() + ext_set.num_functions())); + + using ExtDecl = substrait::extensions::SimpleExtensionDeclaration; + + for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) { + ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); + if (type_record.id.empty()) continue; + + auto ext_decl = internal::make_unique(); + + if (type_record.is_variation) { + auto type_var = internal::make_unique(); + type_var->set_extension_uri_reference(map[type_record.id.uri]); + type_var->set_type_variation_anchor(anchor); + type_var->set_name(type_record.id.name.to_string()); + ext_decl->set_allocated_extension_type_variation(type_var.release()); + } else { + auto type = internal::make_unique(); + type->set_extension_uri_reference(map[type_record.id.uri]); + type->set_type_anchor(anchor); + type->set_name(type_record.id.name.to_string()); + ext_decl->set_allocated_extension_type(type.release()); + } + + extensions->AddAllocated(ext_decl.release()); + } + + for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) { + ARROW_ASSIGN_OR_RAISE(auto function_record, ext_set.DecodeFunction(anchor)); + if (function_record.id.empty()) continue; + + auto fn = internal::make_unique(); + fn->set_extension_uri_reference(map[function_record.id.uri]); + fn->set_function_anchor(anchor); + fn->set_name(function_record.id.name.to_string()); + + auto ext_decl = internal::make_unique(); + ext_decl->set_allocated_extension_function(fn.release()); + extensions->AddAllocated(ext_decl.release()); + } + + return Status::OK(); +} + +namespace { +template +void SetElement(size_t i, const Element& element, std::vector* vector) { + DCHECK_LE(i, 1 << 20); + if (i >= vector->size()) { + vector->resize(i + 1); + } + (*vector)[i] = static_cast(element); +} +} // namespace + +Result GetExtensionSetFromPlan(const substrait::Plan& plan, + ExtensionIdRegistry* registry) { + std::vector uris; + for (const auto& uri : plan.extension_uris()) { + SetElement(uri.extension_uri_anchor(), uri.uri(), &uris); + } + + // NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make + // will only store views to memory owned by registry. + + using Id = ExtensionSet::Id; + + std::vector type_ids, function_ids; + std::vector type_is_variation; + for (const auto& ext : plan.extensions()) { + switch (ext.mapping_type_case()) { + case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: { + const auto& type_var = ext.extension_type_variation(); + util::string_view uri = uris[type_var.extension_uri_reference()]; + SetElement(type_var.type_variation_anchor(), Id{uri, type_var.name()}, &type_ids); + SetElement(type_var.type_variation_anchor(), true, &type_is_variation); + break; + } + + case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { + const auto& type = ext.extension_type(); + util::string_view uri = uris[type.extension_uri_reference()]; + SetElement(type.type_anchor(), Id{uri, type.name()}, &type_ids); + SetElement(type.type_anchor(), false, &type_is_variation); + break; + } + + case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { + const auto& fn = ext.extension_function(); + util::string_view uri = uris[fn.extension_uri_reference()]; + SetElement(fn.function_anchor(), Id{uri, fn.name()}, &function_ids); + break; + } + + default: + Unreachable(); + } + } + + return ExtensionSet::Make(std::move(uris), std::move(type_ids), + std::move(type_is_variation), std::move(function_ids), + registry); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h new file mode 100644 index 00000000000..0ab06ece1ce --- /dev/null +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" + +#include "substrait/plan.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +/// \brief Replaces the extension information of a Substrait Plan message with the given +/// extension set, such that the anchors defined therein can be used in the rest of the +/// plan. +/// +/// \param[in] ext_set the extension set to copy the extension information from +/// \param[in,out] plan the Substrait plan message that is to be updated +/// \return success or failure +ARROW_ENGINE_EXPORT +Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan); + +/// \brief Interprets the extension information of a Substrait Plan message into an +/// ExtensionSet. +/// +/// Note that the extension registry is not currently mutated, but may be in the future. +/// +/// \param[in] plan the plan message to take the information from +/// \param[in,out] registry registry defining which Arrow types and compute functions +/// correspond to Substrait's URI/name pairs +ARROW_ENGINE_EXPORT +Result GetExtensionSetFromPlan( + const substrait::Plan& plan, + ExtensionIdRegistry* registry = default_extension_id_registry()); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc new file mode 100644 index 00000000000..ae2244c87f5 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/relation_internal.h" + +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/exec/options.h" +#include "arrow/dataset/file_parquet.h" +#include "arrow/dataset/plan.h" +#include "arrow/dataset/scanner.h" +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/filesystem/localfs.h" + +namespace arrow { +namespace engine { + +template +Status CheckRelCommon(const RelMessage& rel) { + if (rel.has_common()) { + if (rel.common().has_emit()) { + return Status::NotImplemented("substrait::RelCommon::Emit"); + } + if (rel.common().has_hint()) { + return Status::NotImplemented("substrait::RelCommon::Hint"); + } + if (rel.common().has_advanced_extension()) { + return Status::NotImplemented("substrait::RelCommon::advanced_extension"); + } + } + if (rel.has_advanced_extension()) { + return Status::NotImplemented("substrait AdvancedExtensions"); + } + return Status::OK(); +} + +Result FromProto(const substrait::Rel& rel, + const ExtensionSet& ext_set) { + static bool dataset_init = false; + if (!dataset_init) { + dataset_init = true; + dataset::internal::Initialize(); + } + + switch (rel.rel_type_case()) { + case substrait::Rel::RelTypeCase::kRead: { + const auto& read = rel.read(); + RETURN_NOT_OK(CheckRelCommon(read)); + + ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set)); + + auto scan_options = std::make_shared(); + + if (read.has_filter()) { + ARROW_ASSIGN_OR_RAISE(scan_options->filter, FromProto(read.filter(), ext_set)); + } + + if (read.has_projection()) { + // NOTE: scan_options->projection is not used by the scanner and thus can't be + // used for this + return Status::NotImplemented("substrait::ReadRel::projection"); + } + + if (!read.has_local_files()) { + return Status::NotImplemented( + "substrait::ReadRel with read_type other than LocalFiles"); + } + + if (read.local_files().has_advanced_extension()) { + return Status::NotImplemented( + "substrait::ReadRel::LocalFiles::advanced_extension"); + } + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + std::vector> fragments; + + for (const auto& item : read.local_files().items()) { + if (!item.has_uri_file()) { + return Status::NotImplemented( + "substrait::ReadRel::LocalFiles::FileOrFiles with " + "path_type other than uri_file"); + } + + if (item.format() != + substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) { + return Status::NotImplemented( + "substrait::ReadRel::LocalFiles::FileOrFiles::format " + "other than FILE_FORMAT_PARQUET"); + } + + if (!util::string_view{item.uri_file()}.starts_with("file:///")) { + return Status::NotImplemented( + "substrait::ReadRel::LocalFiles::FileOrFiles::uri_file " + "with other than local filesystem (file:///)"); + } + auto path = item.uri_file().substr(7); + + if (item.partition_index() != 0) { + return Status::NotImplemented( + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::partition_index"); + } + + if (item.start() != 0) { + return Status::NotImplemented( + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start offset"); + } + + if (item.length() != 0) { + return Status::NotImplemented( + "non-default substrait::ReadRel::LocalFiles::FileOrFiles::length"); + } + + ARROW_ASSIGN_OR_RAISE(auto fragment, format->MakeFragment(dataset::FileSource{ + std::move(path), filesystem})); + fragments.push_back(std::move(fragment)); + } + + ARROW_ASSIGN_OR_RAISE( + auto ds, dataset::FileSystemDataset::Make( + std::move(base_schema), /*root_partition=*/compute::literal(true), + std::move(format), std::move(filesystem), std::move(fragments))); + + return compute::Declaration{ + "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}}; + } + + case substrait::Rel::RelTypeCase::kFilter: { + const auto& filter = rel.filter(); + RETURN_NOT_OK(CheckRelCommon(filter)); + + if (!filter.has_input()) { + return Status::Invalid("substrait::FilterRel with no input relation"); + } + ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set)); + + if (!filter.has_condition()) { + return Status::Invalid("substrait::FilterRel with no condition expression"); + } + ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set)); + + return compute::Declaration::Sequence({ + std::move(input), + {"filter", compute::FilterNodeOptions{std::move(condition)}}, + }); + } + + case substrait::Rel::RelTypeCase::kProject: { + const auto& project = rel.project(); + RETURN_NOT_OK(CheckRelCommon(project)); + + if (!project.has_input()) { + return Status::Invalid("substrait::ProjectRel with no input relation"); + } + ARROW_ASSIGN_OR_RAISE(auto input, FromProto(project.input(), ext_set)); + + std::vector expressions; + for (const auto& expr : project.expressions()) { + expressions.emplace_back(); + ARROW_ASSIGN_OR_RAISE(expressions.back(), FromProto(expr, ext_set)); + } + + return compute::Declaration::Sequence({ + std::move(input), + {"project", compute::ProjectNodeOptions{std::move(expressions)}}, + }); + } + + default: + break; + } + + return Status::NotImplemented( + "conversion to arrow::compute::Declaration from Substrait relation ", + rel.DebugString()); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h new file mode 100644 index 00000000000..d9b90f50779 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/engine/substrait/extension_types.h" +#include "arrow/engine/substrait/serde.h" +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" + +#include "substrait/relations.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::Rel&, const ExtensionSet&); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc new file mode 100644 index 00000000000..ea916d86757 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/serde.h" + +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/plan_internal.h" +#include "arrow/engine/substrait/relation_internal.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/util/string_view.h" + +#include +#include +#include +#include +#include +#include + +namespace arrow { +namespace engine { + +Status ParseFromBufferImpl(const Buffer& buf, const std::string& full_name, + google::protobuf::Message* message) { + google::protobuf::io::ArrayInputStream buf_stream{buf.data(), + static_cast(buf.size())}; + + if (message->ParseFromZeroCopyStream(&buf_stream)) { + return Status::OK(); + } + return Status::IOError("ParseFromZeroCopyStream failed for ", full_name); +} + +template +Result ParseFromBuffer(const Buffer& buf) { + Message message; + ARROW_RETURN_NOT_OK( + ParseFromBufferImpl(buf, Message::descriptor()->full_name(), &message)); + return message; +} + +Result DeserializeRelation(const Buffer& buf, + const ExtensionSet& ext_set) { + ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer(buf)); + return FromProto(rel, ext_set); +} + +Result> DeserializePlan( + const Buffer& buf, const ConsumerFactory& consumer_factory, + ExtensionSet* ext_set_out) { + ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); + + ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan)); + + std::vector sink_decls; + for (const substrait::PlanRel& plan_rel : plan.relations()) { + if (plan_rel.has_root()) { + return Status::NotImplemented("substrait::PlanRel with custom output field names"); + } + ARROW_ASSIGN_OR_RAISE(auto decl, FromProto(plan_rel.rel(), ext_set)); + + // pipe each relation into a consuming_sink node + auto sink_decl = compute::Declaration::Sequence({ + std::move(decl), + {"consuming_sink", compute::ConsumingSinkNodeOptions{consumer_factory()}}, + }); + sink_decls.push_back(std::move(sink_decl)); + } + + if (ext_set_out) { + *ext_set_out = std::move(ext_set); + } + return sink_decls; +} + +Result> DeserializeSchema(const Buffer& buf, + const ExtensionSet& ext_set) { + ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer(buf)); + return FromProto(named_struct, ext_set); +} + +Result> SerializeSchema(const Schema& schema, + ExtensionSet* ext_set) { + ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(schema, ext_set)); + std::string serialized = named_struct->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +Result> DeserializeType(const Buffer& buf, + const ExtensionSet& ext_set) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseFromBuffer(buf)); + ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(type, ext_set)); + return std::move(type_nullable.first); +} + +Result> SerializeType(const DataType& type, + ExtensionSet* ext_set) { + ARROW_ASSIGN_OR_RAISE(auto st_type, ToProto(type, /*nullable=*/true, ext_set)); + std::string serialized = st_type->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +Result DeserializeExpression(const Buffer& buf, + const ExtensionSet& ext_set) { + ARROW_ASSIGN_OR_RAISE(auto expr, ParseFromBuffer(buf)); + return FromProto(expr, ext_set); +} + +Result> SerializeExpression(const compute::Expression& expr, + ExtensionSet* ext_set) { + ARROW_ASSIGN_OR_RAISE(auto st_expr, ToProto(expr, ext_set)); + std::string serialized = st_expr->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +namespace internal { + +template +static Status CheckMessagesEquivalent(const Buffer& l_buf, const Buffer& r_buf) { + ARROW_ASSIGN_OR_RAISE(auto l, ParseFromBuffer(l_buf)); + ARROW_ASSIGN_OR_RAISE(auto r, ParseFromBuffer(r_buf)); + + using google::protobuf::util::MessageDifferencer; + + std::string out; + google::protobuf::io::StringOutputStream out_stream{&out}; + MessageDifferencer::StreamReporter reporter{&out_stream}; + + MessageDifferencer differencer; + differencer.set_message_field_comparison(MessageDifferencer::EQUIVALENT); + differencer.ReportDifferencesTo(&reporter); + + if (differencer.Compare(l, r)) { + return Status::OK(); + } + return Status::Invalid("Messages were not equivalent: ", out); +} + +Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_buf, + const Buffer& r_buf) { + if (message_name == "Type") { + return CheckMessagesEquivalent(l_buf, r_buf); + } + + if (message_name == "NamedStruct") { + return CheckMessagesEquivalent(l_buf, r_buf); + } + + if (message_name == "Schema") { + return Status::Invalid( + "There is no substrait message named Schema. The substrait message type which " + "corresponds to Schemas is NamedStruct"); + } + + if (message_name == "Expression") { + return CheckMessagesEquivalent(l_buf, r_buf); + } + + if (message_name == "Rel") { + return CheckMessagesEquivalent(l_buf, r_buf); + } + + if (message_name == "Relation") { + return Status::Invalid( + "There is no substrait message named Relation. You probably meant \"Rel\""); + } + + return Status::Invalid("Unsupported message name ", message_name, + " for CheckMessagesEquivalent"); +} + +inline google::protobuf::util::TypeResolver* GetGeneratedTypeResolver() { + static std::unique_ptr type_resolver; + if (!type_resolver) { + type_resolver.reset(google::protobuf::util::NewTypeResolverForDescriptorPool( + /*url_prefix=*/"", google::protobuf::DescriptorPool::generated_pool())); + } + return type_resolver.get(); +} + +Result> SubstraitFromJSON(util::string_view type_name, + util::string_view json) { + std::string type_url = "/substrait." + type_name.to_string(); + + google::protobuf::io::ArrayInputStream json_stream{json.data(), + static_cast(json.size())}; + + std::string out; + google::protobuf::io::StringOutputStream out_stream{&out}; + + auto status = google::protobuf::util::JsonToBinaryStream( + GetGeneratedTypeResolver(), type_url, &json_stream, &out_stream); + + if (!status.ok()) { + return Status::Invalid("JsonToBinaryStream returned ", status); + } + return Buffer::FromString(std::move(out)); +} + +Result SubstraitToJSON(util::string_view type_name, const Buffer& buf) { + std::string type_url = "/substrait." + type_name.to_string(); + + google::protobuf::io::ArrayInputStream buf_stream{buf.data(), + static_cast(buf.size())}; + + std::string out; + google::protobuf::io::StringOutputStream out_stream{&out}; + + auto status = google::protobuf::util::BinaryToJsonStream( + GetGeneratedTypeResolver(), type_url, &buf_stream, &out_stream); + if (!status.ok()) { + return Status::Invalid("BinaryToJsonStream returned ", status); + } + return out; +} + +} // namespace internal +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h new file mode 100644 index 00000000000..9e63a1befb5 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/visibility.h" +#include "arrow/result.h" +#include "arrow/util/string_view.h" + +namespace arrow { +namespace engine { + +/// Factory function type for generating the node that consumes the batches produced by +/// each toplevel Substrait relation when deserializing a Substrait Plan. +using ConsumerFactory = std::function()>; + +/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan +/// message +/// \param[in] consumer_factory factory function for generating the node that consumes +/// the batches produced by each toplevel Substrait relation +/// \param[out] ext_set if non-null, the extension mapping used by the Substrait Plan is +/// returned here. +/// \return a vector of ExecNode declarations, one for each toplevel relation in the +/// Substrait Plan +ARROW_ENGINE_EXPORT Result> DeserializePlan( + const Buffer& buf, const ConsumerFactory& consumer_factory, + ExtensionSet* ext_set = NULLPTR); + +/// \brief Deserializes a Substrait Type message to the corresponding Arrow type +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type +/// message +/// \param[in] ext_set the extension mapping to use, normally provided by the +/// surrounding Plan message +/// \return the corresponding Arrow data type +ARROW_ENGINE_EXPORT +Result> DeserializeType(const Buffer& buf, + const ExtensionSet& ext_set); + +/// \brief Serializes an Arrow type to a Substrait Type message +/// +/// \param[in] type the Arrow data type to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add a +/// mapping for the given type +/// \return a buffer containing the protobuf serialization of the corresponding Substrait +/// Type message +ARROW_ENGINE_EXPORT +Result> SerializeType(const DataType& type, + ExtensionSet* ext_set); + +/// \brief Deserializes a Substrait NamedStruct message to an Arrow schema +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait +/// NamedStruct message +/// \param[in] ext_set the extension mapping to use, normally provided by the +/// surrounding Plan message +/// \return the corresponding Arrow schema +ARROW_ENGINE_EXPORT +Result> DeserializeSchema(const Buffer& buf, + const ExtensionSet& ext_set); + +/// \brief Serializes an Arrow schema to a Substrait NamedStruct message +/// +/// \param[in] schema the Arrow schema to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// mappings for the types used in the schema +/// \return a buffer containing the protobuf serialization of the corresponding Substrait +/// NamedStruct message +ARROW_ENGINE_EXPORT +Result> SerializeSchema(const Schema& schema, + ExtensionSet* ext_set); + +/// \brief Deserializes a Substrait Expression message to a compute expression +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait +/// Expression message +/// \param[in] ext_set the extension mapping to use, normally provided by the +/// surrounding Plan message +/// \return the corresponding Arrow compute expression +ARROW_ENGINE_EXPORT +Result DeserializeExpression(const Buffer& buf, + const ExtensionSet& ext_set); + +/// \brief Serializes an Arrow compute expression to a Substrait Expression message +/// +/// \param[in] expr the Arrow compute expression to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// mappings for the types used in the expression +/// \return a buffer containing the protobuf serialization of the corresponding Substrait +/// Expression message +ARROW_ENGINE_EXPORT +Result> SerializeExpression(const compute::Expression& expr, + ExtensionSet* ext_set); + +/// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait +/// Rel message +/// \param[in] ext_set the extension mapping to use, normally provided by the +/// surrounding Plan message +/// \return the corresponding ExecNode declaration +ARROW_ENGINE_EXPORT Result DeserializeRelation( + const Buffer& buf, const ExtensionSet& ext_set); + +namespace internal { + +/// \brief Checks whether two protobuf serializations of a particular Substrait message +/// type are equivalent +/// +/// Note that a binary comparison of the two buffers is insufficient. One reason for this +/// is that the fields of a message can be specified in any order in the serialization. +/// +/// \param[in] message_name the name of the Substrait message type to check +/// \param[in] l_buf buffer containing the first protobuf serialization to compare +/// \param[in] r_buf buffer containing the second protobuf serialization to compare +/// \return success if equivalent, failure if not +ARROW_ENGINE_EXPORT +Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_buf, + const Buffer& r_buf); + +/// \brief Utility function to convert a JSON serialization of a Substrait message to +/// its binary serialization +/// +/// \param[in] type_name the name of the Substrait message type to convert +/// \param[in] json the JSON string to convert +/// \return a buffer filled with the binary protobuf serialization of message +ARROW_ENGINE_EXPORT +Result> SubstraitFromJSON(util::string_view type_name, + util::string_view json); + +/// \brief Utility function to convert a binary protobuf serialization of a Substrait +/// message to JSON +/// +/// \param[in] type_name the name of the Substrait message type to convert +/// \param[in] buf the buffer containing the binary protobuf serialization of the message +/// \return a JSON string representing the message +ARROW_ENGINE_EXPORT +Result SubstraitToJSON(util::string_view type_name, const Buffer& buf); + +} // namespace internal +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc new file mode 100644 index 00000000000..6af5d71521f --- /dev/null +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -0,0 +1,728 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/serde.h" + +#include +#include +#include +#include + +#include "arrow/compute/exec/expression_internal.h" +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/scanner.h" +#include "arrow/engine/substrait/extension_types.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/util/key_value_metadata.h" + +using testing::ElementsAre; +using testing::Eq; +using testing::HasSubstr; + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +const std::shared_ptr kBoringSchema = schema({ + field("bool", boolean()), + field("i8", int8()), + field("i32", int32()), + field("i32_req", int32(), /*nullable=*/false), + field("u32", uint32()), + field("i64", int64()), + field("f32", float32()), + field("f32_req", float32(), /*nullable=*/false), + field("f64", float64()), + field("date64", date64()), + field("str", utf8()), + field("list_i32", list(int32())), + field("struct", struct_({ + field("i32", int32()), + field("str", utf8()), + field("struct_i32_str", + struct_({field("i32", int32()), field("str", utf8())})), + })), + field("list_struct", list(struct_({ + field("i32", int32()), + field("str", utf8()), + field("struct_i32_str", struct_({field("i32", int32()), + field("str", utf8())})), + }))), + field("dict_str", dictionary(int32(), utf8())), + field("dict_i32", dictionary(int32(), int32())), + field("ts_ns", timestamp(TimeUnit::NANO)), +}); + +std::shared_ptr StripFieldNames(std::shared_ptr type) { + if (type->id() == Type::STRUCT) { + FieldVector fields(type->num_fields()); + for (int i = 0; i < type->num_fields(); ++i) { + fields[i] = type->field(i)->WithName(""); + } + return struct_(std::move(fields)); + } + + if (type->id() == Type::LIST) { + return list(type->field(0)->WithName("")); + } + + return type; +} + +inline compute::Expression UseBoringRefs(const compute::Expression& expr) { + if (expr.literal()) return expr; + + if (auto ref = expr.field_ref()) { + return compute::field_ref(*ref->FindOne(*kBoringSchema)); + } + + auto modified_call = *CallNotNull(expr); + for (auto& arg : modified_call.arguments) { + arg = UseBoringRefs(arg); + } + return compute::Expression{std::move(modified_call)}; +} + +TEST(Substrait, SupportedTypes) { + auto ExpectEq = [](util::string_view json, std::shared_ptr expected_type) { + ARROW_SCOPED_TRACE(json); + + ExtensionSet empty; + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Type", json)); + ASSERT_OK_AND_ASSIGN(auto type, DeserializeType(*buf, empty)); + + EXPECT_EQ(*type, *expected_type); + + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeType(*type, &empty)); + EXPECT_EQ(empty.num_types(), 0); + + // FIXME chokes on NULLABILITY_UNSPECIFIED + // EXPECT_THAT(internal::CheckMessagesEquivalent("Type", *buf, *serialized), Ok()); + + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeType(*serialized, empty)); + + EXPECT_EQ(*roundtripped, *expected_type); + }; + + ExpectEq(R"({"bool": {}})", boolean()); + + ExpectEq(R"({"i8": {}})", int8()); + ExpectEq(R"({"i16": {}})", int16()); + ExpectEq(R"({"i32": {}})", int32()); + ExpectEq(R"({"i64": {}})", int64()); + + ExpectEq(R"({"fp32": {}})", float32()); + ExpectEq(R"({"fp64": {}})", float64()); + + ExpectEq(R"({"string": {}})", utf8()); + ExpectEq(R"({"binary": {}})", binary()); + + ExpectEq(R"({"timestamp": {}})", timestamp(TimeUnit::MICRO)); + ExpectEq(R"({"date": {}})", date32()); + ExpectEq(R"({"time": {}})", time64(TimeUnit::MICRO)); + ExpectEq(R"({"timestamp_tz": {}})", timestamp(TimeUnit::MICRO, "UTC")); + ExpectEq(R"({"interval_year": {}})", interval_year()); + ExpectEq(R"({"interval_day": {}})", interval_day()); + + ExpectEq(R"({"uuid": {}})", uuid()); + + ExpectEq(R"({"fixed_char": {"length": 32}})", fixed_char(32)); + ExpectEq(R"({"varchar": {"length": 1024}})", varchar(1024)); + ExpectEq(R"({"fixed_binary": {"length": 32}})", fixed_size_binary(32)); + + ExpectEq(R"({"decimal": {"precision": 27, "scale": 5}})", decimal128(27, 5)); + + ExpectEq(R"({"struct": { + "types": [ + {"i64": {}}, + {"list": {"type": {"string":{}} }} + ] + }})", + struct_({ + field("", int64()), + field("", list(utf8())), + })); + + ExpectEq(R"({"map": { + "key": {"string":{"nullability": "NULLABILITY_REQUIRED"}}, + "value": {"string":{}} + }})", + map(utf8(), field("", utf8()), false)); +} + +TEST(Substrait, SupportedExtensionTypes) { + ExtensionSet ext_set; + + for (auto expected_type : { + null(), + uint8(), + uint16(), + uint32(), + uint64(), + }) { + auto anchor = ext_set.num_types(); + + EXPECT_THAT(ext_set.EncodeType(*expected_type), ResultWith(Eq(anchor))); + ASSERT_OK_AND_ASSIGN( + auto buf, + internal::SubstraitFromJSON( + "Type", "{\"user_defined_type_reference\": " + std::to_string(anchor) + "}")); + + ASSERT_OK_AND_ASSIGN(auto type, DeserializeType(*buf, ext_set)); + EXPECT_EQ(*type, *expected_type); + + auto size = ext_set.num_types(); + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeType(*type, &ext_set)); + EXPECT_EQ(ext_set.num_types(), size) << "was already added to the set above"; + + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeType(*serialized, ext_set)); + EXPECT_EQ(*roundtripped, *expected_type); + } +} + +TEST(Substrait, NamedStruct) { + ExtensionSet ext_set; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("NamedStruct", R"({ + "struct": { + "types": [ + {"i64": {}}, + {"list": {"type": {"string":{}} }}, + {"struct": { + "types": [ + {"fp32": {"nullability": "NULLABILITY_REQUIRED"}}, + {"string": {}} + ] + }}, + {"list": {"type": {"string":{}} }}, + ] + }, + "names": ["a", "b", "c", "d", "e", "f"] + })")); + ASSERT_OK_AND_ASSIGN(auto schema, DeserializeSchema(*buf, ext_set)); + Schema expected_schema({ + field("a", int64()), + field("b", list(utf8())), + field("c", struct_({ + field("d", float32(), /*nullable=*/false), + field("e", utf8()), + })), + field("f", list(utf8())), + }); + EXPECT_EQ(*schema, expected_schema); + + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeSchema(*schema, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeSchema(*serialized, ext_set)); + EXPECT_EQ(*roundtripped, expected_schema); + + // too few names + ASSERT_OK_AND_ASSIGN(buf, internal::SubstraitFromJSON("NamedStruct", R"({ + "struct": {"types": [{"i32": {}}, {"i32": {}}, {"i32": {}}]}, + "names": [] + })")); + EXPECT_THAT(DeserializeSchema(*buf, ext_set), Raises(StatusCode::Invalid)); + + // too many names + ASSERT_OK_AND_ASSIGN(buf, internal::SubstraitFromJSON("NamedStruct", R"({ + "struct": {"types": []}, + "names": ["a", "b", "c"] + })")); + EXPECT_THAT(DeserializeSchema(*buf, ext_set), Raises(StatusCode::Invalid)); + + // no schema metadata allowed + EXPECT_THAT(SerializeSchema(Schema({}, key_value_metadata({{"ext", "yes"}})), &ext_set), + Raises(StatusCode::Invalid)); + + // no schema metadata allowed + EXPECT_THAT( + SerializeSchema(Schema({field("a", int32(), key_value_metadata({{"ext", "yes"}}))}), + &ext_set), + Raises(StatusCode::Invalid)); +} + +TEST(Substrait, NoEquivalentArrowType) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON( + "Type", R"({"user_defined_type_reference": 99})")); + ExtensionSet empty; + ASSERT_THAT( + DeserializeType(*buf, empty), + Raises(StatusCode::Invalid, HasSubstr("did not have a corresponding anchor"))); +} + +TEST(Substrait, NoEquivalentSubstraitType) { + for (auto type : { + date64(), + timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::NANO), + timestamp(TimeUnit::MICRO, "New York"), + time32(TimeUnit::SECOND), + time32(TimeUnit::MILLI), + time64(TimeUnit::NANO), + + decimal256(76, 67), + + sparse_union({field("i8", int8()), field("f32", float32())}), + dense_union({field("i8", int8()), field("f32", float32())}), + dictionary(int32(), utf8()), + + fixed_size_list(float16(), 3), + + duration(TimeUnit::MICRO), + + large_utf8(), + large_binary(), + large_list(utf8()), + }) { + ARROW_SCOPED_TRACE(type->ToString()); + ExtensionSet set; + EXPECT_THAT(SerializeType(*type, &set), Raises(StatusCode::NotImplemented)); + } +} + +TEST(Substrait, SupportedLiterals) { + auto ExpectEq = [](util::string_view json, Datum expected_value) { + ARROW_SCOPED_TRACE(json); + + ASSERT_OK_AND_ASSIGN( + auto buf, internal::SubstraitFromJSON("Expression", + "{\"literal\":" + json.to_string() + "}")); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set)); + + ASSERT_TRUE(expr.literal()); + ASSERT_THAT(*expr.literal(), DataEq(expected_value)); + + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set)); + EXPECT_EQ(ext_set.num_functions(), 0); // shouldn't need extensions for core literals + + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + + ASSERT_TRUE(roundtripped.literal()); + ASSERT_THAT(*roundtripped.literal(), DataEq(expected_value)); + }; + + ExpectEq(R"({"boolean": true})", Datum(true)); + + ExpectEq(R"({"i8": 34})", Datum(int8_t(34))); + ExpectEq(R"({"i16": 34})", Datum(int16_t(34))); + ExpectEq(R"({"i32": 34})", Datum(int32_t(34))); + ExpectEq(R"({"i64": "34"})", Datum(int64_t(34))); + + ExpectEq(R"({"fp32": 3.5})", Datum(3.5F)); + ExpectEq(R"({"fp64": 7.125})", Datum(7.125)); + + ExpectEq(R"({"string": "hello world"})", Datum("hello world")); + + ExpectEq(R"({"binary": "enp6"})", BinaryScalar(Buffer::FromString("zzz"))); + + ExpectEq(R"({"timestamp": "579"})", TimestampScalar(579, TimeUnit::MICRO)); + + ExpectEq(R"({"date": "5"})", Date32Scalar(5)); + + ExpectEq(R"({"time": "64"})", Time64Scalar(64, TimeUnit::MICRO)); + + ExpectEq(R"({"interval_year_to_month": {"years": 34, "months": 3}})", + ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")), + interval_year())); + + ExpectEq(R"({"interval_day_to_second": {"days": 34, "seconds": 3}})", + ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")), + interval_day())); + + ExpectEq(R"({"fixed_char": "zzz"})", + ExtensionScalar( + FixedSizeBinaryScalar(Buffer::FromString("zzz"), fixed_size_binary(3)), + fixed_char(3))); + + ExpectEq(R"({"var_char": {"value": "zzz", "length": 1024}})", + ExtensionScalar(StringScalar("zzz"), varchar(1024))); + + ExpectEq(R"({"fixed_binary": "enp6"})", + FixedSizeBinaryScalar(Buffer::FromString("zzz"), fixed_size_binary(3))); + + ExpectEq( + R"({"decimal": {"value": "0gKWSQAAAAAAAAAAAAAAAA==", "precision": 27, "scale": 5}})", + Decimal128Scalar(Decimal128("123456789.0"), decimal128(27, 5))); + + ExpectEq(R"({"timestamp_tz": "579"})", TimestampScalar(579, TimeUnit::MICRO, "UTC")); + + // special case for empty lists + ExpectEq(R"({"empty_list": {"type": {"i32": {}}}})", + ScalarFromJSON(list(int32()), "[]")); + + ExpectEq(R"({"struct": { + "fields": [ + {"i64": "32"}, + {"list": {"values": [ + {"string": "hello"}, + {"string": "world"} + ]}} + ] + }})", + ScalarFromJSON(struct_({ + field("", int64()), + field("", list(utf8())), + }), + R"([32, ["hello", "world"]])")); + + // check null scalars: + for (auto type : { + boolean(), + + int8(), + int64(), + + timestamp(TimeUnit::MICRO), + interval_year(), + + struct_({ + field("", int64()), + field("", list(utf8())), + }), + }) { + ExtensionSet set; + ASSERT_OK_AND_ASSIGN(auto buf, SerializeType(*type, &set)); + ASSERT_OK_AND_ASSIGN(auto json, internal::SubstraitToJSON("Type", *buf)); + ExpectEq("{\"null\": " + json + "}", MakeNullScalar(type)); + } +} + +TEST(Substrait, CannotDeserializeLiteral) { + ExtensionSet ext_set; + + // Invalid: missing List.element_type + ASSERT_OK_AND_ASSIGN( + auto buf, internal::SubstraitFromJSON("Expression", + R"({"literal": {"list": {"values": []}}})")); + EXPECT_THAT(DeserializeExpression(*buf, ext_set), Raises(StatusCode::Invalid)); + + // Invalid: required null literal + ASSERT_OK_AND_ASSIGN( + buf, + internal::SubstraitFromJSON( + "Expression", + R"({"literal": {"null": {"bool": {"nullability": "NULLABILITY_REQUIRED"}}}})")); + EXPECT_THAT(DeserializeExpression(*buf, ext_set), Raises(StatusCode::Invalid)); + + // no equivalent arrow scalar + // FIXME no way to specify scalars of user_defined_type_reference +} + +TEST(Substrait, FieldRefRoundTrip) { + for (FieldRef ref : { + // by name + FieldRef("i32"), + FieldRef("ts_ns"), + FieldRef("struct"), + + // by index + FieldRef(0), + FieldRef(1), + FieldRef(kBoringSchema->num_fields() - 1), + FieldRef(kBoringSchema->GetFieldIndex("struct")), + + // nested + FieldRef("struct", "i32"), + FieldRef("struct", "struct_i32_str", "i32"), + FieldRef(kBoringSchema->GetFieldIndex("struct"), 1), + }) { + ARROW_SCOPED_TRACE(ref.ToString()); + ASSERT_OK_AND_ASSIGN(auto expr, compute::field_ref(ref).Bind(*kBoringSchema)); + + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set)); + EXPECT_EQ(ext_set.num_functions(), + 0); // shouldn't need extensions for core field references + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + ASSERT_TRUE(roundtripped.field_ref()); + + ASSERT_OK_AND_ASSIGN(auto expected, ref.FindOne(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto actual, roundtripped.field_ref()->FindOne(*kBoringSchema)); + EXPECT_EQ(actual.indices(), expected.indices()); + } +} + +TEST(Substrait, RecursiveFieldRef) { + FieldRef ref("struct", "str"); + + ARROW_SCOPED_TRACE(ref.ToString()); + ASSERT_OK_AND_ASSIGN(auto expr, compute::field_ref(ref).Bind(*kBoringSchema)); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto expected, internal::SubstraitFromJSON("Expression", R"({ + "selection": { + "directReference": { + "structField": { + "field": 12, + "child": { + "structField": { + "field": 1 + } + } + } + }, + "rootReference": {} + } + })")); + ASSERT_OK(internal::CheckMessagesEquivalent("Expression", *serialized, *expected)); +} + +TEST(Substrait, FieldRefsInExpressions) { + ASSERT_OK_AND_ASSIGN(auto expr, + compute::call("struct_field", + {compute::call("if_else", + { + compute::literal(true), + compute::field_ref("struct"), + compute::field_ref("struct"), + })}, + compute::StructFieldOptions({0})) + .Bind(*kBoringSchema)); + + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto expected, internal::SubstraitFromJSON("Expression", R"({ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "expression": { + "if_then": { + "ifs": [ + { + "if": {"literal": {"boolean": true}}, + "then": {"selection": {"directReference": {"structField": {"field": 12}}}} + } + ], + "else": {"selection": {"directReference": {"structField": {"field": 12}}}} + } + } + } + })")); + ASSERT_OK(internal::CheckMessagesEquivalent("Expression", *serialized, *expected)); +} + +TEST(Substrait, CallSpecialCaseRoundTrip) { + for (compute::Expression expr : { + compute::call("if_else", + { + compute::literal(true), + compute::field_ref({"struct", 1}), + compute::field_ref("str"), + }), + + compute::call( + "case_when", + { + compute::call("make_struct", + {compute::literal(false), compute::literal(true)}, + compute::MakeStructOptions({"cond1", "cond2"})), + compute::field_ref({"struct", "str"}), + compute::field_ref({"struct", "struct_i32_str", "str"}), + compute::field_ref("str"), + }), + + compute::call("list_element", + { + compute::field_ref("list_i32"), + compute::literal(3), + }), + + compute::call("struct_field", + {compute::call("list_element", + { + compute::field_ref("list_struct"), + compute::literal(42), + })}, + arrow::compute::StructFieldOptions({1})), + + compute::call("struct_field", + {compute::call("list_element", + { + compute::field_ref("list_struct"), + compute::literal(42), + })}, + arrow::compute::StructFieldOptions({2, 0})), + + compute::call("struct_field", + {compute::call("if_else", + { + compute::literal(true), + compute::field_ref("struct"), + compute::field_ref("struct"), + })}, + compute::StructFieldOptions({0})), + }) { + ARROW_SCOPED_TRACE(expr.ToString()); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set)); + + // These are special cased as core expressions in substrait; shouldn't require any + // extensions. + EXPECT_EQ(ext_set.num_functions(), 0); + + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*kBoringSchema)); + EXPECT_EQ(UseBoringRefs(roundtripped), UseBoringRefs(expr)); + } +} + +TEST(Substrait, CallExtensionFunction) { + for (compute::Expression expr : { + compute::call("add", {compute::literal(0), compute::literal(1)}), + }) { + ARROW_SCOPED_TRACE(expr.ToString()); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set)); + + // These require an extension, so we should have a single-element ext_set. + EXPECT_EQ(ext_set.num_functions(), 1); + + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*kBoringSchema)); + EXPECT_EQ(UseBoringRefs(roundtripped), UseBoringRefs(expr)); + } +} + +TEST(Substrait, ReadRel) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Rel", R"({ + "read": { + "base_schema": { + "struct": { + "types": [ {"i64": {}}, {"bool": {}} ] + }, + "names": ["i", "b"] + }, + "filter": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + }, + "local_files": { + "items": [ + { + "uri_file": "file:///tmp/dat1.parquet", + "format": "FILE_FORMAT_PARQUET" + }, + { + "uri_file": "file:///tmp/dat2.parquet", + "format": "FILE_FORMAT_PARQUET" + } + ] + } + } + })")); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto rel, DeserializeRelation(*buf, ext_set)); + + // converting a ReadRel produces a scan Declaration + ASSERT_EQ(rel.factory_name, "scan"); + const auto& scan_node_options = + checked_cast(*rel.options); + + // filter on the boolean field (#1) + EXPECT_EQ(scan_node_options.scan_options->filter, compute::field_ref(1)); + + // dataset is a FileSystemDataset in parquet format with the specified schema + ASSERT_EQ(scan_node_options.dataset->type_name(), "filesystem"); + const auto& dataset = + checked_cast(*scan_node_options.dataset); + EXPECT_THAT(dataset.files(), ElementsAre("/tmp/dat1.parquet", "/tmp/dat2.parquet")); + EXPECT_EQ(dataset.format()->type_name(), "parquet"); + EXPECT_EQ(*dataset.schema(), Schema({field("i", int64()), field("b", boolean())})); +} + +TEST(Substrait, ExtensionSetFromPlan) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ {"i64": {}}, {"bool": {}} ] + }, + "names": ["i", "b"] + }, + "local_files": { "items": [] } + } + }} + ], + "extension_uris": [ + { + "extension_uri_anchor": 7, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + } + ], + "extensions": [ + {"extension_type": { + "extension_uri_reference": 7, + "type_anchor": 42, + "name": "null" + }}, + {"extension_type_variation": { + "extension_uri_reference": 7, + "type_variation_anchor": 23, + "name": "u8" + }}, + {"extension_function": { + "extension_uri_reference": 7, + "function_anchor": 42, + "name": "add" + }} + ] + })")); + + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlan( + *buf, [] { return std::shared_ptr{nullptr}; }, + &ext_set)); + + EXPECT_OK_AND_ASSIGN(auto decoded_null_type, ext_set.DecodeType(42)); + EXPECT_EQ(decoded_null_type.id.uri, kArrowExtTypesUri); + EXPECT_EQ(decoded_null_type.id.name, "null"); + EXPECT_EQ(*decoded_null_type.type, NullType()); + EXPECT_FALSE(decoded_null_type.is_variation); + + EXPECT_OK_AND_ASSIGN(auto decoded_uint8_type, ext_set.DecodeType(23)); + EXPECT_EQ(decoded_uint8_type.id.uri, kArrowExtTypesUri); + EXPECT_EQ(decoded_uint8_type.id.name, "u8"); + EXPECT_EQ(*decoded_uint8_type.type, UInt8Type()); + EXPECT_TRUE(decoded_uint8_type.is_variation); + + EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42)); + EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri); + EXPECT_EQ(decoded_add_func.id.name, "add"); + EXPECT_EQ(decoded_add_func.name, "add"); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc new file mode 100644 index 00000000000..49ca1bbfabf --- /dev/null +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -0,0 +1,494 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/type_internal.h" + +#include +#include + +#include "arrow/engine/substrait/extension_types.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/visit_type_inline.h" + +namespace arrow { +namespace engine { + +namespace internal { +using ::arrow::internal::make_unique; +} // namespace internal + +namespace { + +template +Status CheckVariation(const TypeMessage& type) { + if (type.type_variation_reference() == 0) return Status::OK(); + return Status::NotImplemented("Type variations for ", type.DebugString()); +} + +template +bool IsNullable(const TypeMessage& type) { + // FIXME what can we do with NULLABILITY_UNSPECIFIED + return type.nullability() != substrait::Type::NULLABILITY_REQUIRED; +} + +template +Result, bool>> FromProtoImpl(const TypeMessage& type, + A&&... args) { + RETURN_NOT_OK(CheckVariation(type)); + + return std::make_pair(std::static_pointer_cast( + std::make_shared(std::forward(args)...)), + IsNullable(type)); +} + +template +Result, bool>> FromProtoImpl( + const TypeMessage& type, std::shared_ptr type_factory(A...), A&&... args) { + RETURN_NOT_OK(CheckVariation(type)); + + return std::make_pair( + std::static_pointer_cast(type_factory(std::forward(args)...)), + IsNullable(type)); +} + +template +Result FieldsFromProto(int size, const Types& types, + const NextName& next_name, + const ExtensionSet& ext_set) { + FieldVector fields(size); + for (int i = 0; i < size; ++i) { + std::string name = next_name(); + std::shared_ptr type; + bool nullable; + + if (types[i].has_struct_()) { + const auto& struct_ = types[i].struct_(); + + ARROW_ASSIGN_OR_RAISE( + type, FieldsFromProto(struct_.types_size(), struct_.types(), next_name, ext_set) + .Map(arrow::struct_)); + + nullable = IsNullable(struct_); + } else { + ARROW_ASSIGN_OR_RAISE(std::tie(type, nullable), FromProto(types[i], ext_set)); + } + + fields[i] = field(std::move(name), std::move(type), nullable); + } + return fields; +} + +} // namespace + +Result, bool>> FromProto( + const substrait::Type& type, const ExtensionSet& ext_set) { + switch (type.kind_case()) { + case substrait::Type::kBool: + return FromProtoImpl(type.bool_()); + + case substrait::Type::kI8: + return FromProtoImpl(type.i8()); + case substrait::Type::kI16: + return FromProtoImpl(type.i16()); + case substrait::Type::kI32: + return FromProtoImpl(type.i32()); + case substrait::Type::kI64: + return FromProtoImpl(type.i64()); + + case substrait::Type::kFp32: + return FromProtoImpl(type.fp32()); + case substrait::Type::kFp64: + return FromProtoImpl(type.fp64()); + + case substrait::Type::kString: + return FromProtoImpl(type.string()); + case substrait::Type::kBinary: + return FromProtoImpl(type.binary()); + + case substrait::Type::kTimestamp: + return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); + case substrait::Type::kTimestampTz: + return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, + TimestampTzTimezoneString()); + case substrait::Type::kDate: + return FromProtoImpl(type.date()); + + case substrait::Type::kTime: + return FromProtoImpl(type.time(), TimeUnit::MICRO); + + case substrait::Type::kIntervalYear: + return FromProtoImpl(type.interval_year(), interval_year); + + case substrait::Type::kIntervalDay: + return FromProtoImpl(type.interval_day(), interval_day); + + case substrait::Type::kUuid: + return FromProtoImpl(type.uuid(), uuid); + + case substrait::Type::kFixedChar: + return FromProtoImpl(type.fixed_char(), fixed_char, type.fixed_char().length()); + + case substrait::Type::kVarchar: + return FromProtoImpl(type.varchar(), varchar, type.varchar().length()); + + case substrait::Type::kFixedBinary: + return FromProtoImpl(type.fixed_binary(), + type.fixed_binary().length()); + + case substrait::Type::kDecimal: { + const auto& decimal = type.decimal(); + return FromProtoImpl(decimal, decimal.precision(), decimal.scale()); + } + + case substrait::Type::kStruct: { + const auto& struct_ = type.struct_(); + + ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto( + struct_.types_size(), struct_.types(), + /*next_name=*/[] { return ""; }, ext_set)); + + return FromProtoImpl(struct_, std::move(fields)); + } + + case substrait::Type::kList: { + const auto& list = type.list(); + + if (!list.has_type()) { + return Status::Invalid( + "While converting to ListType encountered a missing item type in ", + list.DebugString()); + } + + ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(list.type(), ext_set)); + return FromProtoImpl( + list, field("item", std::move(type_nullable.first), type_nullable.second)); + } + + case substrait::Type::kMap: { + const auto& map = type.map(); + + static const std::array kMissing = {"key and value", "value", "key", + nullptr}; + if (auto missing = kMissing[map.has_key() + map.has_value() * 2]) { + return Status::Invalid("While converting to MapType encountered missing ", + missing, " type in ", map.DebugString()); + } + + ARROW_ASSIGN_OR_RAISE(auto key_nullable, FromProto(map.key(), ext_set)); + ARROW_ASSIGN_OR_RAISE(auto value_nullable, FromProto(map.value(), ext_set)); + + if (key_nullable.second) { + return Status::Invalid( + "While converting to MapType encountered nullable key field in ", + map.DebugString()); + } + + return FromProtoImpl( + map, std::move(key_nullable.first), + field("value", std::move(value_nullable.first), value_nullable.second)); + } + + case substrait::Type::kUserDefinedTypeReference: { + uint32_t anchor = type.user_defined_type_reference(); + ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); + return std::make_pair(std::move(type_record.type), true); + } + + default: + break; + } + + return Status::NotImplemented("conversion to arrow::DataType from Substrait type ", + type.DebugString()); +} + +namespace { + +struct DataTypeToProtoImpl { + Status Visit(const NullType& t) { return EncodeUserDefined(t); } + + Status Visit(const BooleanType& t) { + return SetWith(&substrait::Type::set_allocated_bool_); + } + + Status Visit(const Int8Type& t) { return SetWith(&substrait::Type::set_allocated_i8); } + Status Visit(const Int16Type& t) { + return SetWith(&substrait::Type::set_allocated_i16); + } + Status Visit(const Int32Type& t) { + return SetWith(&substrait::Type::set_allocated_i32); + } + Status Visit(const Int64Type& t) { + return SetWith(&substrait::Type::set_allocated_i64); + } + + Status Visit(const UInt8Type& t) { return EncodeUserDefined(t); } + Status Visit(const UInt16Type& t) { return EncodeUserDefined(t); } + Status Visit(const UInt32Type& t) { return EncodeUserDefined(t); } + Status Visit(const UInt64Type& t) { return EncodeUserDefined(t); } + + Status Visit(const HalfFloatType& t) { return EncodeUserDefined(t); } + Status Visit(const FloatType& t) { + return SetWith(&substrait::Type::set_allocated_fp32); + } + Status Visit(const DoubleType& t) { + return SetWith(&substrait::Type::set_allocated_fp64); + } + + Status Visit(const StringType& t) { + return SetWith(&substrait::Type::set_allocated_string); + } + Status Visit(const BinaryType& t) { + return SetWith(&substrait::Type::set_allocated_binary); + } + + Status Visit(const FixedSizeBinaryType& t) { + SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width()); + return Status::OK(); + } + + Status Visit(const Date32Type& t) { + return SetWith(&substrait::Type::set_allocated_date); + } + Status Visit(const Date64Type& t) { return NotImplemented(t); } + + Status Visit(const TimestampType& t) { + if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); + + if (t.timezone() == "") { + return SetWith(&substrait::Type::set_allocated_timestamp); + } + if (t.timezone() == TimestampTzTimezoneString()) { + return SetWith(&substrait::Type::set_allocated_timestamp_tz); + } + + return NotImplemented(t); + } + + Status Visit(const Time32Type& t) { return NotImplemented(t); } + Status Visit(const Time64Type& t) { + if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); + return SetWith(&substrait::Type::set_allocated_time); + } + + Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } + Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); } + + Status Visit(const Decimal128Type& t) { + auto dec = SetWithThen(&substrait::Type::set_allocated_decimal); + dec->set_precision(t.precision()); + dec->set_scale(t.scale()); + return Status::OK(); + } + Status Visit(const Decimal256Type& t) { return NotImplemented(t); } + + Status Visit(const ListType& t) { + // FIXME assert default field name; custom ones won't roundtrip + ARROW_ASSIGN_OR_RAISE( + auto type, ToProto(*t.value_type(), t.value_field()->nullable(), ext_set_)); + SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release()); + return Status::OK(); + } + + Status Visit(const StructType& t) { + auto types = SetWithThen(&substrait::Type::set_allocated_struct_)->mutable_types(); + + types->Reserve(t.num_fields()); + + for (const auto& field : t.fields()) { + if (field->metadata() != nullptr) { + return Status::Invalid("substrait::Type::Struct does not support field metadata"); + } + ARROW_ASSIGN_OR_RAISE(auto type, + ToProto(*field->type(), field->nullable(), ext_set_)); + types->AddAllocated(type.release()); + } + return Status::OK(); + } + + Status Visit(const SparseUnionType& t) { return NotImplemented(t); } + Status Visit(const DenseUnionType& t) { return NotImplemented(t); } + Status Visit(const DictionaryType& t) { return NotImplemented(t); } + + Status Visit(const MapType& t) { + // FIXME assert default field names; custom ones won't roundtrip + auto map = SetWithThen(&substrait::Type::set_allocated_map); + + ARROW_ASSIGN_OR_RAISE(auto key, ToProto(*t.key_type(), /*nullable=*/false, ext_set_)); + map->set_allocated_key(key.release()); + + ARROW_ASSIGN_OR_RAISE(auto value, + ToProto(*t.item_type(), t.item_field()->nullable(), ext_set_)); + map->set_allocated_value(value.release()); + + return Status::OK(); + } + + Status Visit(const ExtensionType& t) { + if (UnwrapUuid(t)) { + return SetWith(&substrait::Type::set_allocated_uuid); + } + + if (auto length = UnwrapFixedChar(t)) { + SetWithThen(&substrait::Type::set_allocated_fixed_char)->set_length(*length); + return Status::OK(); + } + + if (auto length = UnwrapVarChar(t)) { + SetWithThen(&substrait::Type::set_allocated_varchar)->set_length(*length); + return Status::OK(); + } + + if (UnwrapIntervalYear(t)) { + return SetWith(&substrait::Type::set_allocated_interval_year); + } + + if (UnwrapIntervalDay(t)) { + return SetWith(&substrait::Type::set_allocated_interval_day); + } + + return NotImplemented(t); + } + + Status Visit(const FixedSizeListType& t) { return NotImplemented(t); } + Status Visit(const DurationType& t) { return NotImplemented(t); } + Status Visit(const LargeStringType& t) { return NotImplemented(t); } + Status Visit(const LargeBinaryType& t) { return NotImplemented(t); } + Status Visit(const LargeListType& t) { return NotImplemented(t); } + Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } + + template + Sub* SetWithThen(void (substrait::Type::*set_allocated_sub)(Sub*)) { + auto sub = internal::make_unique(); + sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); + + auto out = sub.get(); + (type_->*set_allocated_sub)(sub.release()); + return out; + } + + template + Status SetWith(void (substrait::Type::*set_allocated_sub)(Sub*)) { + return SetWithThen(set_allocated_sub), Status::OK(); + } + + template + Status EncodeUserDefined(const T& t) { + ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(t)); + type_->set_user_defined_type_reference(anchor); + return Status::OK(); + } + + Status NotImplemented(const DataType& t) { + return Status::NotImplemented("conversion to substrait::Type from ", t.ToString()); + } + + Status operator()(const DataType& type) { return VisitTypeInline(type, this); } + + substrait::Type* type_; + bool nullable_; + ExtensionSet* ext_set_; +}; +} // namespace + +Result> ToProto(const DataType& type, bool nullable, + ExtensionSet* ext_set) { + auto out = internal::make_unique(); + RETURN_NOT_OK((DataTypeToProtoImpl{out.get(), nullable, ext_set})(type)); + return std::move(out); +} + +Result> FromProto(const substrait::NamedStruct& named_struct, + const ExtensionSet& ext_set) { + if (!named_struct.has_struct_()) { + return Status::Invalid("While converting ", named_struct.DebugString(), + " no anonymous struct type was provided to which names " + "could be attached."); + } + const auto& struct_ = named_struct.struct_(); + RETURN_NOT_OK(CheckVariation(struct_)); + + int requested_names_count = 0; + ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto( + struct_.types_size(), struct_.types(), + /*next_name=*/ + [&] { + int i = requested_names_count++; + return i < named_struct.names_size() + ? named_struct.names().Get(i) + : ""; + }, + ext_set)); + + if (requested_names_count != named_struct.names_size()) { + return Status::Invalid("While converting ", named_struct.DebugString(), " received ", + named_struct.names_size(), " names but ", + requested_names_count, " struct fields"); + } + + return schema(std::move(fields)); +} + +namespace { +void ToProtoGetDepthFirstNames(const FieldVector& fields, + google::protobuf::RepeatedPtrField* names) { + for (const auto& field : fields) { + *names->Add() = field->name(); + + if (field->type()->id() == Type::STRUCT) { + ToProtoGetDepthFirstNames(field->type()->fields(), names); + } + } +} +} // namespace + +Result> ToProto(const Schema& schema, + ExtensionSet* ext_set) { + if (schema.metadata()) { + return Status::Invalid("substrait::NamedStruct does not support schema metadata"); + } + + auto named_struct = internal::make_unique(); + + auto names = named_struct->mutable_names(); + names->Reserve(schema.num_fields()); + ToProtoGetDepthFirstNames(schema.fields(), names); + + auto struct_ = internal::make_unique(); + auto types = struct_->mutable_types(); + types->Reserve(schema.num_fields()); + + for (const auto& field : schema.fields()) { + if (field->metadata() != nullptr) { + return Status::Invalid("substrait::NamedStruct does not support field metadata"); + } + + ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set)); + types->AddAllocated(type.release()); + } + + named_struct->set_allocated_struct_(struct_.release()); + return std::move(named_struct); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.h b/cpp/src/arrow/engine/substrait/type_internal.h new file mode 100644 index 00000000000..058019c759f --- /dev/null +++ b/cpp/src/arrow/engine/substrait/type_internal.h @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" + +#include "substrait/type.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +ARROW_ENGINE_EXPORT +Result, bool>> FromProto(const substrait::Type&, + const ExtensionSet&); + +ARROW_ENGINE_EXPORT +Result> ToProto(const DataType&, bool nullable, + ExtensionSet*); + +ARROW_ENGINE_EXPORT +Result> FromProto(const substrait::NamedStruct&, + const ExtensionSet&); + +ARROW_ENGINE_EXPORT +Result> ToProto(const Schema&, ExtensionSet*); + +inline std::string TimestampTzTimezoneString() { return "UTC"; } + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/visibility.h b/cpp/src/arrow/engine/visibility.h new file mode 100644 index 00000000000..5b1651f78ab --- /dev/null +++ b/cpp/src/arrow/engine/visibility.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4251) +#else +#pragma GCC diagnostic ignored "-Wattributes" +#endif + +#ifdef ARROW_ENGINE_STATIC +#define ARROW_ENGINE_EXPORT +#elif defined(ARROW_ENGINE_EXPORTING) +#define ARROW_ENGINE_EXPORT __declspec(dllexport) +#else +#define ARROW_ENGINE_EXPORT __declspec(dllimport) +#endif + +#define ARROW_ENGINE_NO_EXPORT +#else // Not Windows +#ifndef ARROW_ENGINE_EXPORT +#define ARROW_ENGINE_EXPORT __attribute__((visibility("default"))) +#endif +#ifndef ARROW_ENGINE_NO_EXPORT +#define ARROW_ENGINE_NO_EXPORT __attribute__((visibility("hidden"))) +#endif +#endif // Non-Windows + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 2cf8c9913e5..bcebe3e73aa 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -53,7 +53,7 @@ endif() # TODO(wesm): Protobuf shared vs static linking set(FLIGHT_PROTO_PATH "${ARROW_SOURCE_DIR}/../format") -set(FLIGHT_PROTO ${ARROW_SOURCE_DIR}/../format/Flight.proto) +set(FLIGHT_PROTO "${ARROW_SOURCE_DIR}/../format/Flight.proto") set(FLIGHT_GENERATED_PROTO_FILES "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc" "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.h" @@ -152,9 +152,9 @@ endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") # Note, we do not compile the generated Protobuf sources directly, instead -# compiling then via protocol_internal.cc which contains some gRPC template +# compiling them via protocol_internal.cc which contains some gRPC template # overrides to enable Flight-specific optimizations. See comments in -# protobuf-internal.cc +# protocol_internal.cc set(ARROW_FLIGHT_SRCS client.cc client_cookie_middleware.cc diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 23c07e6ade2..23c463b6523 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -475,9 +475,15 @@ Status Scalar::ValidateFull() const { return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } +BinaryScalar::BinaryScalar(std::string s) + : BinaryScalar(Buffer::FromString(std::move(s))) {} + StringScalar::StringScalar(std::string s) : StringScalar(Buffer::FromString(std::move(s))) {} +LargeBinaryScalar::LargeBinaryScalar(std::string s) + : LargeBinaryScalar(Buffer::FromString(std::move(s))) {} + LargeStringScalar::LargeStringScalar(std::string s) : LargeStringScalar(Buffer::FromString(std::move(s))) {} @@ -488,6 +494,12 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, this->value->size()); } +FixedSizeBinaryScalar::FixedSizeBinaryScalar(const std::shared_ptr& value) + : BinaryScalar(value, fixed_size_binary(static_cast(value->size()))) {} + +FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s) + : FixedSizeBinaryScalar(Buffer::FromString(std::move(s))) {} + BaseListScalar::BaseListScalar(std::shared_ptr value, std::shared_ptr type) : Scalar{std::move(type), true}, value(std::move(value)) { diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 9df3e3c74e3..943a6420d80 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -250,6 +250,8 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { explicit BinaryScalar(std::shared_ptr value) : BinaryScalar(std::move(value), binary()) {} + explicit BinaryScalar(std::string s); + BinaryScalar() : BinaryScalar(binary()) {} }; @@ -275,6 +277,8 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { explicit LargeBinaryScalar(std::shared_ptr value) : LargeBinaryScalar(std::move(value), large_binary()) {} + explicit LargeBinaryScalar(std::string s); + LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {} }; @@ -295,7 +299,12 @@ struct ARROW_EXPORT FixedSizeBinaryScalar : public BinaryScalar { FixedSizeBinaryScalar(std::shared_ptr value, std::shared_ptr type); - explicit FixedSizeBinaryScalar(std::shared_ptr type) : BinaryScalar(type) {} + explicit FixedSizeBinaryScalar(const std::shared_ptr& value); + + explicit FixedSizeBinaryScalar(std::string s); + + explicit FixedSizeBinaryScalar(std::shared_ptr type) + : BinaryScalar(std::move(type)) {} }; template @@ -345,8 +354,8 @@ struct ARROW_EXPORT TimestampScalar : public TemporalScalar { using TemporalScalar::TemporalScalar; TimestampScalar(typename TemporalScalar::ValueType value, - TimeUnit::type unit) - : TimestampScalar(std::move(value), timestamp(unit)) {} + TimeUnit::type unit, std::string tz = "") + : TimestampScalar(std::move(value), timestamp(unit, std::move(tz))) {} }; template @@ -533,6 +542,11 @@ struct ARROW_EXPORT ExtensionScalar : public Scalar { ExtensionScalar(std::shared_ptr storage, std::shared_ptr type) : Scalar(std::move(type), true), value(std::move(storage)) {} + template ::value>> + ExtensionScalar(Storage&& storage, std::shared_ptr type) + : ExtensionScalar(std::make_shared(std::move(storage)), std::move(type)) {} + std::shared_ptr value; }; diff --git a/cpp/src/arrow/status_test.cc b/cpp/src/arrow/status_test.cc index 10a79d9b990..a8e1d1ca9a8 100644 --- a/cpp/src/arrow/status_test.cc +++ b/cpp/src/arrow/status_test.cc @@ -179,20 +179,19 @@ TEST(StatusTest, MatcherExplanations) { { testing::StringMatchResultListener listener; EXPECT_TRUE(matcher.MatchAndExplain(Status::Invalid("XXX"), &listener)); - EXPECT_THAT(listener.str(), testing::StrEq("whose value \"Invalid: XXX\" matches")); + EXPECT_THAT(listener.str(), testing::StrEq("whose error matches")); } { testing::StringMatchResultListener listener; EXPECT_FALSE(matcher.MatchAndExplain(Status::OK(), &listener)); - EXPECT_THAT(listener.str(), testing::StrEq("whose value \"OK\" doesn't match")); + EXPECT_THAT(listener.str(), testing::StrEq("whose non-error doesn't match")); } { testing::StringMatchResultListener listener; EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener)); - EXPECT_THAT(listener.str(), - testing::StrEq("whose value \"Type error: XXX\" doesn't match")); + EXPECT_THAT(listener.str(), testing::StrEq("whose error doesn't match")); } } diff --git a/cpp/src/arrow/testing/matchers.h b/cpp/src/arrow/testing/matchers.h index ddfe60f1740..be88c3f93b4 100644 --- a/cpp/src/arrow/testing/matchers.h +++ b/cpp/src/arrow/testing/matchers.h @@ -24,9 +24,11 @@ #include "arrow/datum.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/stl_iterator.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/future.h" +#include "arrow/util/unreachable.h" namespace arrow { @@ -196,8 +198,14 @@ class ErrorMatcher { message_matcher_->MatchAndExplain(status.message(), &value_listener); } - *listener << "whose value " << testing::PrintToString(status.ToString()) - << (match ? " matches" : " doesn't match"); + if (match) { + *listener << "whose error matches"; + } else if (status.ok()) { + *listener << "whose non-error doesn't match"; + } else { + *listener << "whose error doesn't match"; + } + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); return match; } @@ -228,8 +236,7 @@ class OkMatcher { const Status& status = internal::GenericToStatus(maybe_value); const bool match = status.ok(); - *listener << "whose value " << testing::PrintToString(status.ToString()) - << (match ? " matches" : " doesn't match"); + *listener << "whose " << (match ? "non-error matches" : "error doesn't match"); return match; } }; @@ -268,6 +275,9 @@ ErrorMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { class DataEqMatcher { public: + // TODO(bkietz) support EqualOptions, ApproxEquals, etc + // Probably it's better to use something like config-through-key_value_metadata + // as with the random generators to decouple this from EqualOptions etc. explicit DataEqMatcher(Datum expected) : expected_(std::move(expected)) {} template @@ -295,17 +305,34 @@ class DataEqMatcher { return false; } - if (*boxed.type() != *expected_.type()) { - *listener << "whose DataType " << boxed.type()->ToString() << " doesn't match " - << expected_.type()->ToString(); - return false; + if (const auto& boxed_type = boxed.type()) { + if (*boxed_type != *expected_.type()) { + *listener << "whose DataType " << boxed_type->ToString() << " doesn't match " + << expected_.type()->ToString(); + return false; + } + } else if (const auto& boxed_schema = boxed.schema()) { + if (*boxed_schema != *expected_.schema()) { + *listener << "whose Schema " << boxed_schema->ToString() << " doesn't match " + << expected_.schema()->ToString(); + return false; + } + } else { + Unreachable(); } - const bool match = boxed == expected_; - *listener << "whose value "; - PrintTo(boxed, listener->stream()); - *listener << (match ? " matches" : " doesn't match"); - return match; + if (boxed == expected_) { + *listener << "whose value matches"; + return true; + } + + if (listener->IsInterested() && boxed.kind() == Datum::ARRAY) { + *listener << "whose value differs from the expected value by " + << boxed.make_array()->Diff(*expected_.make_array()); + } else { + *listener << "whose value doesn't match"; + } + return false; } Datum expected_; @@ -318,9 +345,66 @@ class DataEqMatcher { Datum expected_; }; +/// Constructs a datum against which arguments are matched template DataEqMatcher DataEq(Data&& dat) { return DataEqMatcher(Datum(std::forward(dat))); } +/// Constructs an array with ArrayFromJSON against which arguments are matched +inline DataEqMatcher DataEqArray(const std::shared_ptr& type, + util::string_view json) { + return DataEq(ArrayFromJSON(type, json)); +} + +/// Constructs an array from a vector of optionals against which arguments are matched +template ::ArrayType, + typename BuilderType = typename TypeTraits::BuilderType, + typename ValueType = + typename ::arrow::stl::detail::DefaultValueAccessor::ValueType> +DataEqMatcher DataEqArray(T type, const std::vector>& values) { + // FIXME(bkietz) broken until DataType is move constructible + BuilderType builder(std::make_shared(std::move(type)), default_memory_pool()); + DCHECK_OK(builder.Reserve(static_cast(values.size()))); + + // pseudo constexpr: + static const bool need_safe_append = !is_fixed_width(T::type_id); + + for (auto value : values) { + if (value) { + if (need_safe_append) { + builder.UnsafeAppend(*value); + } else { + DCHECK_OK(builder.Append(*value)); + } + } else { + builder.UnsafeAppendNull(); + } + } + + return DataEq(builder.Finish().ValueOrDie()); +} + +/// Constructs a scalar with ScalarFromJSON against which arguments are matched +inline DataEqMatcher DataEqScalar(const std::shared_ptr& type, + util::string_view json) { + return DataEq(ScalarFromJSON(type, json)); +} + +/// Constructs a scalar against which arguments are matched +template ::ScalarType, + typename ValueType = typename ScalarType::ValueType> +DataEqMatcher DataEqScalar(T type, util::optional value) { + ScalarType expected(std::make_shared(std::move(type))); + + if (value) { + expected.is_valid = true; + expected.value = std::move(*value); + } + + return DataEq(std::move(expected)); +} + +// HasType, HasSchema matchers + } // namespace arrow diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 4c439841ba2..7381d94b43e 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -82,8 +82,8 @@ class ARROW_EXPORT Fingerprintable { virtual std::string ComputeFingerprint() const = 0; virtual std::string ComputeMetadataFingerprint() const = 0; - mutable std::atomic fingerprint_; - mutable std::atomic metadata_fingerprint_; + mutable std::atomic fingerprint_{NULLPTR}; + mutable std::atomic metadata_fingerprint_{NULLPTR}; }; } // namespace detail @@ -817,7 +817,7 @@ class ARROW_EXPORT Decimal256Type : public DecimalType { class ARROW_EXPORT BaseListType : public NestedType { public: using NestedType::NestedType; - std::shared_ptr value_field() const { return children_[0]; } + const std::shared_ptr& value_field() const { return children_[0]; } std::shared_ptr value_type() const { return children_[0]->type(); } }; diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index 328d7e7ca21..d2c0178b008 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -882,5 +882,14 @@ static inline Status ComputeNullBitmap(MemoryPool* pool, const MemoTableType& me return Status::OK(); } +struct StringViewHash { + // std::hash compatible hasher for use with std::unordered_* + // (the std::hash specialization provided by nonstd constructs std::string + // temporaries then invokes std::hash against those) + hash_t operator()(const util::string_view& value) const { + return ComputeStringHash<0>(value.data(), static_cast(value.size())); + } +}; + } // namespace internal } // namespace arrow diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py index d8eeb7bab0e..dbe0b5c4bda 100644 --- a/dev/archery/archery/cli.py +++ b/dev/archery/archery/cli.py @@ -118,6 +118,12 @@ def _apply_options(cmd, options): @cpp_toolchain_options @click.option("--build-type", default=None, type=build_type, help="CMake's CMAKE_BUILD_TYPE") +@click.option("--build-static", default=True, type=BOOL, + help="Build static libraries") +@click.option("--build-shared", default=True, type=BOOL, + help="Build shared libraries") +@click.option("--build-unity", default=True, type=BOOL, + help="Use CMAKE_UNITY_BUILD") @click.option("--warn-level", default="production", type=warn_level_type, help="Controls compiler warnings -W(no-)error.") @click.option("--use-gold-linker", default=True, type=BOOL, diff --git a/dev/archery/archery/lang/cpp.py b/dev/archery/archery/lang/cpp.py index cf25ba871b5..4ece6ec829b 100644 --- a/dev/archery/archery/lang/cpp.py +++ b/dev/archery/archery/lang/cpp.py @@ -42,7 +42,7 @@ def __init__(self, cc=None, cxx=None, cxx_flags=None, build_type=None, warn_level=None, cpp_package_prefix=None, install_prefix=None, use_conda=None, - build_static=False, build_shared=True, build_unity=True, + build_static=True, build_shared=True, build_unity=True, # tests & examples with_tests=None, with_benchmarks=None, with_examples=None, with_integration=None, diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 6229e0aec39..671b11cd362 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -31,6 +31,7 @@ cpp/src/generated/parquet_constants.cpp cpp/src/generated/parquet_constants.h cpp/src/generated/parquet_types.cpp cpp/src/generated/parquet_types.h +cpp/src/generated/substrait/* cpp/src/plasma/thirdparty/ae/ae.c cpp/src/plasma/thirdparty/ae/ae.h cpp/src/plasma/thirdparty/ae/ae_epoll.c diff --git a/format/substrait/extension_types.yaml b/format/substrait/extension_types.yaml new file mode 100644 index 00000000000..c905c8b04be --- /dev/null +++ b/format/substrait/extension_types.yaml @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# substrait::{ExtensionTypeVariation, ExtensionType}s +# for wrapping types which appear in the arrow type system but +# are not first-class in substrait. These include: +# - null +# - unsigned integers +# - half-precision floating point numbers +# - 32-bit times and 64-bit dates +# - timestamps with units other than microseconds +# - timestamps with timezones other than UTC +# - 256-bit decimals +# - sparse and dense unions +# - dictionary encoded types +# - durations +# - string and binary with 64 bit offsets +# - list with 64-bit offsets +# - interval +# - interval +# - interval +# - arrow::ExtensionTypes +# +# Note that not all of these are currently implemented. In particular, these +# extension types are currently not parameterizable in Substrait, which means +# among other things that we can't declare dictionary type here at all since +# we'd have to declare a different dictionary type for all encoded types +# (but that is an infinite space). Similarly, we would have to declare a +# timestamp variation for all possible timezone strings. +# +# Ultimately these declarations are a promise which needs to be backed by +# equivalent serde in c++. This is handled by default_extension_id_registry(), +# defined in cpp/src/arrow/engine/substrait/extension_set.cc. These files +# currently need to be kept in sync manually; see ARROW-15535. + +type_variations: + - parent: i8 + name: u8 + description: an unsigned 8 bit integer + functions: SEPARATE + - parent: i16 + name: u16 + description: an unsigned 16 bit integer + functions: SEPARATE + - parent: i32 + name: u32 + description: an unsigned 32 bit integer + functions: SEPARATE + - parent: i64 + name: u64 + description: an unsigned 64 bit integer + functions: SEPARATE + + - parent: i16 + name: fp16 + description: a 16 bit floating point number + functions: SEPARATE + +types: + - name: "null" + structure: {} + - name: interval_month + structure: + months: i32 + - name: interval_day_milli + structure: + days: i32 + millis: i32 + - name: interval_month_day_nano + structure: + months: i32 + days: i32 + nanos: i64