diff --git a/ci/cpp-msvc-build-main.bat b/ci/cpp-msvc-build-main.bat index 735073c49cc..fee46e1843c 100644 --- a/ci/cpp-msvc-build-main.bat +++ b/ci/cpp-msvc-build-main.bat @@ -78,6 +78,7 @@ cmake -G "%GENERATOR%" %CMAKE_ARGS% ^ -DARROW_FLIGHT=%ARROW_BUILD_FLIGHT% ^ -DARROW_GANDIVA=%ARROW_BUILD_GANDIVA% ^ -DARROW_DATASET=ON ^ + -DARROW_ENGINE=ON ^ -DARROW_S3=%ARROW_S3% ^ -DARROW_MIMALLOC=ON ^ -DARROW_PARQUET=ON ^ diff --git a/ci/docker/conda-cpp.dockerfile b/ci/docker/conda-cpp.dockerfile index 0e35b6caf6d..190a5fcc0a2 100644 --- a/ci/docker/conda-cpp.dockerfile +++ b/ci/docker/conda-cpp.dockerfile @@ -65,6 +65,7 @@ ENTRYPOINT [ "/bin/bash", "-c", "-l" ] ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=CONDA \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=$CONDA_PREFIX \ diff --git a/ci/docker/conda-integration.dockerfile b/ci/docker/conda-integration.dockerfile index 5672cb8def3..86c33a6cebe 100644 --- a/ci/docker/conda-integration.dockerfile +++ b/ci/docker/conda-integration.dockerfile @@ -45,6 +45,7 @@ ENV ARROW_BUILD_INTEGRATION=ON \ ARROW_FLIGHT=ON \ ARROW_ORC=OFF \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_GANDIVA=OFF \ ARROW_PLASMA=OFF \ ARROW_FILESYSTEM=OFF \ diff --git a/ci/docker/cuda-10.0-cpp.dockerfile b/ci/docker/cuda-10.0-cpp.dockerfile index 0697513d30d..cb077b8e92f 100644 --- a/ci/docker/cuda-10.0-cpp.dockerfile +++ b/ci/docker/cuda-10.0-cpp.dockerfile @@ -76,6 +76,7 @@ ENV ARROW_BUILD_STATIC=OFF \ ARROW_CSV=OFF \ ARROW_CUDA=ON \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_FILESYSTEM=OFF \ ARROW_FLIGHT=OFF \ diff --git a/ci/docker/cuda-10.1-cpp.dockerfile b/ci/docker/cuda-10.1-cpp.dockerfile index 6e86ca97fc5..73ef0603e82 100644 --- a/ci/docker/cuda-10.1-cpp.dockerfile +++ b/ci/docker/cuda-10.1-cpp.dockerfile @@ -76,6 +76,7 @@ ENV ARROW_BUILD_STATIC=OFF \ ARROW_CSV=OFF \ ARROW_CUDA=ON \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ 
ARROW_FILESYSTEM=OFF \ ARROW_FLIGHT=OFF \ diff --git a/ci/docker/cuda-9.1-cpp.dockerfile b/ci/docker/cuda-9.1-cpp.dockerfile index bf3242ea180..b7c302d822c 100644 --- a/ci/docker/cuda-9.1-cpp.dockerfile +++ b/ci/docker/cuda-9.1-cpp.dockerfile @@ -76,6 +76,7 @@ ENV ARROW_BUILD_STATIC=OFF \ ARROW_CSV=OFF \ ARROW_CUDA=ON \ ARROW_DATASET=OFF \ + ARROW_ENGINE=OFF \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_FILESYSTEM=OFF \ ARROW_FLIGHT=OFF \ diff --git a/ci/docker/debian-10-cpp.dockerfile b/ci/docker/debian-10-cpp.dockerfile index e51f482d842..11d1b10c45d 100644 --- a/ci/docker/debian-10-cpp.dockerfile +++ b/ci/docker/debian-10-cpp.dockerfile @@ -61,6 +61,7 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=/usr/local \ diff --git a/ci/docker/fedora-29-cpp.dockerfile b/ci/docker/fedora-29-cpp.dockerfile index 94adb5e7762..0dc1b603fee 100644 --- a/ci/docker/fedora-29-cpp.dockerfile +++ b/ci/docker/fedora-29-cpp.dockerfile @@ -59,6 +59,7 @@ RUN dnf update -y && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=ON \ ARROW_GANDIVA_JAVA=ON \ ARROW_GANDIVA=OFF \ diff --git a/ci/docker/ubuntu-14.04-cpp.dockerfile b/ci/docker/ubuntu-14.04-cpp.dockerfile index 5f24f9a353b..4266f23ef10 100644 --- a/ci/docker/ubuntu-14.04-cpp.dockerfile +++ b/ci/docker/ubuntu-14.04-cpp.dockerfile @@ -58,6 +58,7 @@ RUN apt-get update -y -q && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=OFF \ ARROW_GANDIVA_JAVA=OFF \ ARROW_GANDIVA=OFF \ diff --git a/ci/docker/ubuntu-16.04-cpp.dockerfile b/ci/docker/ubuntu-16.04-cpp.dockerfile index c6773662717..eba39219059 100644 --- a/ci/docker/ubuntu-16.04-cpp.dockerfile +++ b/ci/docker/ubuntu-16.04-cpp.dockerfile @@ -66,6 +66,7 @@ RUN apt-get update -y -q && \ ENV ARROW_BUILD_BENCHMARKS=OFF 
\ ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_GANDIVA_JAVA=OFF \ ARROW_GANDIVA=ON \ diff --git a/ci/docker/ubuntu-18.04-cpp.dockerfile b/ci/docker/ubuntu-18.04-cpp.dockerfile index 92c75445e62..64b106e33d2 100644 --- a/ci/docker/ubuntu-18.04-cpp.dockerfile +++ b/ci/docker/ubuntu-18.04-cpp.dockerfile @@ -95,6 +95,7 @@ RUN apt-get update -y -q && \ ENV ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ + ARROW_ENGINE=ON \ ARROW_FLIGHT=OFF \ ARROW_GANDIVA=ON \ ARROW_HDFS=ON \ diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 133da043eb6..f95ba9fce6d 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -82,6 +82,7 @@ build() { -DARROW_COMPUTE=ON \ -DARROW_CSV=ON \ -DARROW_DATASET=ON \ + -DARROW_ENGINE=ON \ -DARROW_FILESYSTEM=ON \ -DARROW_HDFS=OFF \ -DARROW_JEMALLOC=OFF \ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 0286987caa1..0233f975914 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -58,6 +58,7 @@ cmake -G "${CMAKE_GENERATOR:-Ninja}" \ -DARROW_CUDA=${ARROW_CUDA:-OFF} \ -DARROW_CXXFLAGS=${ARROW_CXXFLAGS:-} \ -DARROW_DATASET=${ARROW_DATASET:-ON} \ + -DARROW_ENGINE=${ARROW_ENGINE:-ON} \ -DARROW_DEPENDENCY_SOURCE=${ARROW_DEPENDENCY_SOURCE:-AUTO} \ -DARROW_EXTRA_ERROR_CONTEXT=${ARROW_EXTRA_ERROR_CONTEXT:-OFF} \ -DARROW_ENABLE_TIMING_TESTS=${ARROW_ENABLE_TIMING_TESTS:-ON} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 553dee28244..b86acef1d8e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -296,7 +296,12 @@ if(ARROW_CUDA OR ARROW_FLIGHT OR ARROW_PARQUET OR ARROW_BUILD_TESTS) set(ARROW_IPC ON) endif() +if(ARROW_ENGINE) + set(ARROW_DATASET ON) +endif() + if(ARROW_DATASET) + set(ARROW_PARQUET ON) set(ARROW_FILESYSTEM ON) endif() diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index f307a6d10da..eece40c859d 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ 
b/cpp/cmake_modules/DefineOptions.cmake @@ -173,6 +173,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF) + define_option(ARROW_ENGINE "Build the Arrow Query Engine Modules" OFF) + define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF) define_option(ARROW_FLIGHT diff --git a/cpp/cmake_modules/FindArrowEngine.cmake b/cpp/cmake_modules/FindArrowEngine.cmake new file mode 100644 index 00000000000..200fcaa5427 --- /dev/null +++ b/cpp/cmake_modules/FindArrowEngine.cmake @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# - Find Arrow Query Engine (arrow/engine/api.h, libarrow_engine.a, libarrow_engine.so) +# +# This module requires Arrow from which it uses +# arrow_find_package() +# +# This module defines +# ARROW_ENGINE_FOUND, whether Arrow Query Engine has been found +# ARROW_ENGINE_IMPORT_LIB, +# path to libarrow_engine's import library (Windows only) +# ARROW_ENGINE_INCLUDE_DIR, directory containing headers +# ARROW_ENGINE_LIB_DIR, directory containing Arrow Query Engine libraries +# ARROW_ENGINE_SHARED_LIB, path to libarrow_engine's shared library +# ARROW_ENGINE_STATIC_LIB, path to libarrow_engine.a + +if(DEFINED ARROW_ENGINE_FOUND) + return() +endif() + +set(find_package_arguments) +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION) + list(APPEND find_package_arguments + "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}") +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED) + list(APPEND find_package_arguments REQUIRED) +endif() +if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY) + list(APPEND find_package_arguments QUIET) +endif() +find_package(Arrow ${find_package_arguments}) +find_package(ArrowDataset ${find_package_arguments}) + +if(ARROW_FOUND AND ARROW_DATASET_FOUND) + arrow_find_package(ARROW_ENGINE + "${ARROW_HOME}" + arrow_engine + arrow/engine/api.h + ArrowEngine + arrow-engine) + if(NOT ARROW_ENGINE_VERSION) + set(ARROW_ENGINE_VERSION "${ARROW_VERSION}") + endif() +endif() + +if("${ARROW_ENGINE_VERSION}" VERSION_EQUAL "${ARROW_VERSION}") + set(ARROW_ENGINE_VERSION_MATCH TRUE) +else() + set(ARROW_ENGINE_VERSION_MATCH FALSE) +endif() + +mark_as_advanced(ARROW_ENGINE_IMPORT_LIB + ARROW_ENGINE_INCLUDE_DIR + ARROW_ENGINE_LIBS + ARROW_ENGINE_LIB_DIR + ARROW_ENGINE_SHARED_IMP_LIB + ARROW_ENGINE_SHARED_LIB + ARROW_ENGINE_STATIC_LIB + ARROW_ENGINE_VERSION + ARROW_ENGINE_VERSION_MATCH) + +find_package_handle_standard_args(ArrowEngine + REQUIRED_VARS + ARROW_ENGINE_INCLUDE_DIR + ARROW_ENGINE_LIB_DIR + ARROW_ENGINE_VERSION_MATCH + VERSION_VAR + ARROW_ENGINE_VERSION) +set(ARROW_ENGINE_FOUND 
${ArrowEngine_FOUND}) + +if(ArrowEngine_FOUND AND NOT ArrowEngine_FIND_QUIETLY) + message(STATUS "Found the Arrow Engine by ${ARROW_ENGINE_FIND_APPROACH}") + message( + STATUS "Found the Arrow Engine shared library: ${ARROW_ENGINE_SHARED_LIB}" + ) + message( + STATUS "Found the Arrow Engine import library: ${ARROW_ENGINE_IMPORT_LIB}" + ) + message( + STATUS "Found the Arrow Engine static library: ${ARROW_ENGINE_STATIC_LIB}" + ) +endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 3454cd0c87d..8d2ba524620 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -586,6 +586,10 @@ if(ARROW_DATASET) add_subdirectory(dataset) endif() +if(ARROW_ENGINE) + add_subdirectory(engine) +endif() + if(ARROW_FILESYSTEM) add_subdirectory(filesystem) endif() diff --git a/cpp/src/arrow/compute/kernels/sum_internal.h b/cpp/src/arrow/compute/kernels/sum_internal.h index 302d004e399..21f28da4db5 100644 --- a/cpp/src/arrow/compute/kernels/sum_internal.h +++ b/cpp/src/arrow/compute/kernels/sum_internal.h @@ -54,6 +54,32 @@ struct FindAccumulatorType> { using Type = DoubleType; }; +#define ACCUMULATOR_TYPE_CASE(ID, TYPE) \ + case Type::ID: \ + return TypeTraits::Type>::type_singleton(); + +static inline std::shared_ptr GetAccumulatorType( + const std::shared_ptr& type) { + switch (type->id()) { + ACCUMULATOR_TYPE_CASE(INT8, Int8Type) + ACCUMULATOR_TYPE_CASE(INT16, Int16Type) + ACCUMULATOR_TYPE_CASE(INT32, Int32Type) + ACCUMULATOR_TYPE_CASE(INT64, Int64Type) + ACCUMULATOR_TYPE_CASE(UINT8, UInt8Type) + ACCUMULATOR_TYPE_CASE(UINT16, UInt16Type) + ACCUMULATOR_TYPE_CASE(UINT32, UInt32Type) + ACCUMULATOR_TYPE_CASE(UINT64, UInt64Type) + ACCUMULATOR_TYPE_CASE(FLOAT, FloatType) + ACCUMULATOR_TYPE_CASE(DOUBLE, DoubleType) + default: + return nullptr; + } + + ARROW_UNREACHABLE; +} + +#undef ACCUMULATOR_TYPE_CASE + template class SumAggregateFunction final : public AggregateFunctionStaticState { using CType = typename TypeTraits::CType; diff 
--git a/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in new file mode 100644 index 00000000000..43cce1be535 --- /dev/null +++ b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This config sets the following variables in your project:: +# +# ArrowEngine_FOUND - true if Arrow Query engine is found on the system +# +# This config sets the following targets in your project:: +# +# arrow_engine_shared - for linked as shared library if shared library is built +# arrow_engine_static - for linked as static library if static library is built + +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) +find_dependency(Arrow) +find_dependency(ArrowDataset) + +# Load targets only once. If we load targets multiple times, CMake reports +# already existent target error. 
+if(NOT (TARGET arrow_engine_shared OR TARGET arrow_engine_static)) + include("${CMAKE_CURRENT_LIST_DIR}/ArrowEngineTargets.cmake") +endif() diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt new file mode 100644 index 00000000000..b4420320ad6 --- /dev/null +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +add_custom_target(arrow_engine) + +arrow_install_all_headers("arrow/engine") + +set(ARROW_ENGINE_SRCS catalog.cc expression.cc logical_plan.cc) + +set(ARROW_ENGINE_LINK_STATIC arrow_static arrow_dataset_static) +set(ARROW_ENGINE_LINK_SHARED arrow_shared arrow_dataset_shared) + +add_arrow_lib(arrow_engine + CMAKE_PACKAGE_NAME + ArrowEngine + PKG_CONFIG_NAME + arrow-engine + OUTPUTS + ARROW_ENGINE_LIBRARIES + SOURCES + ${ARROW_ENGINE_SRCS} + PRECOMPILED_HEADERS + "$<$:arrow/engine/pch.h>" + PRIVATE_INCLUDES + ${ARROW_ENGINE_PRIVATE_INCLUDES} + SHARED_LINK_LIBS + ${ARROW_ENGINE_LINK_SHARED} + STATIC_LINK_LIBS + ${ARROW_ENGINE_LINK_STATIC}) + +if(ARROW_TEST_LINKAGE STREQUAL "static") + set(ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_static ${ARROW_TEST_STATIC_LINK_LIBS}) +else() + set(ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_shared ${ARROW_TEST_SHARED_LINK_LIBS}) +endif() + +foreach(LIB_TARGET ${ARROW_ENGINE_LIBRARIES}) + target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_EN_EXPORTING) +endforeach() + +# Adding unit tests part of the "engine" portion of the test suite +function(ADD_ARROW_ENGINE_TEST REL_TEST_NAME) + set(options) + set(one_value_args PREFIX) + set(multi_value_args LABELS) + cmake_parse_arguments(ARG + "${options}" + "${one_value_args}" + "${multi_value_args}" + ${ARGN}) + + if(ARG_PREFIX) + set(PREFIX ${ARG_PREFIX}) + else() + set(PREFIX "arrow-engine") + endif() + + if(ARG_LABELS) + set(LABELS ${ARG_LABELS}) + else() + set(LABELS "arrow_engine") + endif() + + add_arrow_test(${REL_TEST_NAME} + EXTRA_LINK_LIBS + ${ARROW_ENGINE_TEST_LINK_LIBS} + PREFIX + ${PREFIX} + LABELS + ${LABELS} + ${ARG_UNPARSED_ARGUMENTS}) +endfunction() + +# +# Unit tests +# + +add_arrow_engine_test(catalog_test PREFIX arrow-engine) +add_arrow_engine_test(expression_test PREFIX arrow-engine) +add_arrow_engine_test(logical_plan_test PREFIX arrow-engine) diff --git a/cpp/src/arrow/engine/api.h b/cpp/src/arrow/engine/api.h new file mode 100644 index 00000000000..85ac052b217 --- 
/dev/null +++ b/cpp/src/arrow/engine/api.h @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/engine/catalog.h" +#include "arrow/engine/expression.h" +#include "arrow/engine/logical_plan.h" diff --git a/cpp/src/arrow/engine/arrow-engine.pc.in b/cpp/src/arrow/engine/arrow-engine.pc.in new file mode 100644 index 00000000000..0ceabbedf68 --- /dev/null +++ b/cpp/src/arrow/engine/arrow-engine.pc.in @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: Apache Arrow Query Engine +Description: Apache Arrow Query Engine provides an API to execute queries on Arrow table and datasets +Version: @ARROW_VERSION@ +Requires: arrow arrow-dataset +Libs: -L${libdir} -larrow_engine diff --git a/cpp/src/arrow/engine/catalog.cc b/cpp/src/arrow/engine/catalog.cc new file mode 100644 index 00000000000..d5479c222d2 --- /dev/null +++ b/cpp/src/arrow/engine/catalog.cc @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/engine/catalog.h" + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/table.h" + +#include "arrow/dataset/dataset.h" + +namespace arrow { +namespace engine { + +// +// Catalog +// + +using Entry = Catalog::Entry; + +Catalog::Catalog(std::unordered_map datasets) + : datasets_(std::move(datasets)) {} + +Result Catalog::Get(const std::string& key) const { + auto value = datasets_.find(key); + if (value != datasets_.end()) return value->second; + return Status::KeyError("Table '", key, "' not found in catalog."); +} + +Result> Catalog::GetSchema(const std::string& key) const { + auto as_schema = [](const Entry& entry) -> Result> { + return entry.schema(); + }; + return Get(key).Map(as_schema); +} + +Result> Catalog::Make(const std::vector& datasets) { + CatalogBuilder builder; + + for (const auto& key_val : datasets) { + RETURN_NOT_OK(builder.Add(key_val)); + } + + return builder.Finish(); +} + +// +// Catalog::Entry +// + +Entry::Entry(std::shared_ptr dataset, std::string name) + : dataset_(std::move(dataset)), name_(std::move(name)) {} + +Entry::Entry(std::shared_ptr table, std::string name) + : dataset_(std::make_shared(std::move(table))), + name_(std::move(name)) {} + +const std::shared_ptr& Entry::dataset() const { return dataset_; } + +bool Entry::operator==(const Entry& other) const { + // Entries are unique by name in a catalog, but we can still protect with + // pointer equality. 
+ return name_ == other.name_ && dataset_ == other.dataset_; +} + +const std::shared_ptr& Entry::schema() const { return dataset()->schema(); } + +// +// CatalogBuilder +// + +Status CatalogBuilder::Add(Entry entry) { + const std::string name = entry.name();  // copy: emplace() below may move from entry + if (name.empty()) { + return Status::Invalid("Key in catalog can't be empty"); + } + + if (entry.dataset() == nullptr) { + return Status::Invalid("Dataset entry can't be null."); + } + + auto inserted = datasets_.emplace(name, std::move(entry)); + if (!inserted.second) { + return Status::KeyError("Dataset '", name, "' already in catalog."); + } + + return Status::OK(); +} + +Status CatalogBuilder::Add(std::string name, std::shared_ptr dataset) { + return Add(Entry(std::move(dataset), std::move(name))); +} + +Status CatalogBuilder::Add(std::string name, std::shared_ptr
table) { + if (table == nullptr) { + return Status::Invalid("Table entry can't be null."); + } + + return Add(Entry(std::move(table), std::move(name))); +} + +Result> CatalogBuilder::Finish() { + return std::shared_ptr(new Catalog(std::move(datasets_))); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/catalog.h b/cpp/src/arrow/engine/catalog.h new file mode 100644 index 00000000000..fc16f9c7c0e --- /dev/null +++ b/cpp/src/arrow/engine/catalog.h @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" +#include "arrow/util/variant.h" + +namespace arrow { + +namespace dataset { +class Dataset; +} + +namespace engine { + +/// Catalog is made of named Table/Dataset to be referenced in LogicalPlans. +class ARROW_EN_EXPORT Catalog { + public: + class Entry; + + static Result> Make(const std::vector& tables); + + Result Get(const std::string& name) const; + Result> GetSchema(const std::string& name) const; + + class ARROW_EN_EXPORT Entry { + public: + Entry(std::shared_ptr dataset, std::string name); + Entry(std::shared_ptr
table, std::string name); + + const std::string& name() const { return name_; } + + const std::shared_ptr& dataset() const; + + const std::shared_ptr& schema() const; + + bool operator==(const Entry& other) const; + + private: + std::shared_ptr dataset_; + std::string name_; + }; + + private: + friend class CatalogBuilder; + explicit Catalog(std::unordered_map datasets); + + std::unordered_map datasets_; +}; + +class ARROW_EN_EXPORT CatalogBuilder { + public: + Status Add(Catalog::Entry entry); + Status Add(std::string name, std::shared_ptr); + Status Add(std::string name, std::shared_ptr
); + + Result> Finish(); + + private: + std::unordered_map datasets_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/catalog_test.cc b/cpp/src/arrow/engine/catalog_test.cc new file mode 100644 index 00000000000..cb65109c722 --- /dev/null +++ b/cpp/src/arrow/engine/catalog_test.cc @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/testing/gtest_util.h" + +#include "arrow/engine/catalog.h" +#include "arrow/table.h" +#include "arrow/type.h" + +namespace arrow { +namespace engine { + +using Entry = Catalog::Entry; + +class TestCatalog : public testing::Test { + public: + std::shared_ptr schema_ = schema({field("f", int32())}); + std::shared_ptr
table(std::shared_ptr schema) const { + return MockTable(schema); + } + std::shared_ptr
table() const { return table(schema_); } +}; + +void AssertCatalogKeyIs(const std::shared_ptr& catalog, const std::string& key, + const std::shared_ptr
& expected) { + ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); + EXPECT_EQ(t.name(), key); + + ASSERT_OK_AND_ASSIGN(auto schema, catalog->GetSchema(key)); + AssertSchemaEqual(*schema, *expected->schema()); +} + +TEST_F(TestCatalog, EmptyCatalog) { + ASSERT_OK_AND_ASSIGN(auto empty_catalog, Catalog::Make({})); + ASSERT_RAISES(KeyError, empty_catalog->Get("")); + ASSERT_RAISES(KeyError, empty_catalog->Get("a_key")); +} + +TEST_F(TestCatalog, Make) { + auto key_1 = "a"; + auto table_1 = table(schema({field(key_1, int32())})); + auto key_2 = "b"; + auto table_2 = table(schema({field(key_2, int32())})); + auto key_3 = "c"; + auto table_3 = table(schema({field(key_3, int32())})); + + std::vector tables{Entry(table_1, key_1), Entry(table_2, key_2), + Entry(table_3, key_3)}; + + ASSERT_OK_AND_ASSIGN(auto catalog, Catalog::Make(std::move(tables))); + AssertCatalogKeyIs(catalog, key_1, table_1); + AssertCatalogKeyIs(catalog, key_2, table_2); + AssertCatalogKeyIs(catalog, key_3, table_3); +} + +class TestCatalogBuilder : public TestCatalog {}; + +TEST_F(TestCatalogBuilder, EmptyCatalog) { + CatalogBuilder builder; + ASSERT_OK_AND_ASSIGN(auto empty_catalog, builder.Finish()); + ASSERT_RAISES(KeyError, empty_catalog->Get("a_key")); +} + +TEST_F(TestCatalogBuilder, Basic) { + auto key_1 = "a"; + auto table_1 = table(schema({field(key_1, int32())})); + auto key_2 = "b"; + auto table_2 = table(schema({field(key_2, int32())})); + auto key_3 = "c"; + auto table_3 = table(schema({field(key_3, int32())})); + + CatalogBuilder builder; + ASSERT_OK(builder.Add(key_1, table_1)); + ASSERT_OK(builder.Add(key_2, table_2)); + ASSERT_OK(builder.Add(key_3, table_3)); + ASSERT_OK_AND_ASSIGN(auto catalog, builder.Finish()); + + AssertCatalogKeyIs(catalog, key_1, table_1); + AssertCatalogKeyIs(catalog, key_2, table_2); + AssertCatalogKeyIs(catalog, key_3, table_3); + + ASSERT_RAISES(KeyError, catalog->Get("invalid_key")); +} + +TEST_F(TestCatalogBuilder, NullOrEmptyKeys) { + CatalogBuilder 
builder; + + auto invalid_key = ""; + // Invalid empty key + ASSERT_RAISES(Invalid, builder.Add(invalid_key, table())); + + auto valid_key = "valid_key"; + // Invalid nullptr Table + ASSERT_RAISES(Invalid, builder.Add(valid_key, std::shared_ptr
{})); + // Invalid nullptr Dataset + ASSERT_RAISES(Invalid, builder.Add(valid_key, std::shared_ptr{})); +} + +TEST_F(TestCatalogBuilder, DuplicateKeys) { + CatalogBuilder builder; + + auto key = "a_key"; + + ASSERT_OK(builder.Add(key, table())); + // Key already in catalog + ASSERT_RAISES(KeyError, builder.Add(key, table())); + + // Should still yield a valid catalog if requested. + ASSERT_OK_AND_ASSIGN(auto catalog, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto t, catalog->Get(key)); + EXPECT_EQ(t.name(), key); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/expression.cc b/cpp/src/arrow/engine/expression.cc new file mode 100644 index 00000000000..121f86d84d3 --- /dev/null +++ b/cpp/src/arrow/engine/expression.cc @@ -0,0 +1,455 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/engine/expression.h" + +#include "arrow/compute/kernels/sum_internal.h" +#include "arrow/scalar.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { +namespace engine { + +// +// ExprType +// + +std::string ShapeToString(ExprType::Shape shape) { + switch (shape) { + case ExprType::SCALAR: + return "scalar"; + case ExprType::ARRAY: + return "array"; + case ExprType::TABLE: + return "table"; + } + + return ""; +} + +ExprType ExprType::Scalar(std::shared_ptr type) { + return ExprType(std::move(type), SCALAR); +} + +ExprType ExprType::Array(std::shared_ptr type) { + return ExprType(std::move(type), ARRAY); +} + +ExprType ExprType::Table(std::shared_ptr schema) { + return ExprType(std::move(schema), TABLE); +} + +ExprType ExprType::Table(std::vector> fields) { + return ExprType(arrow::schema(std::move(fields)), TABLE); +} + +ExprType::ExprType(std::shared_ptr schema, Shape shape) + : schema_(std::move(schema)), shape_(shape) { + DCHECK_EQ(shape, TABLE); +} + +ExprType::ExprType(std::shared_ptr type, Shape shape) + : data_type_(std::move(type)), shape_(shape) { + DCHECK_NE(shape, TABLE); +} + +ExprType::ExprType(const ExprType& other) : shape_(other.shape()) { + switch (other.shape()) { + case SCALAR: + case ARRAY: + data_type_ = other.type(); + break; + case TABLE: + schema_ = other.schema(); + } +} + +ExprType::ExprType(ExprType&& other) : shape_(other.shape()) { + switch (other.shape()) { + case SCALAR: + case ARRAY: + data_type_ = std::move(other.type()); + break; + case TABLE: + schema_ = std::move(other.schema()); + } +} + +ExprType::~ExprType() { + switch (shape()) { + case SCALAR: + case ARRAY: + data_type_.reset(); + break; + case TABLE: + schema_.reset(); + } +} + +bool ExprType::Equals(const ExprType& other) const { + if (this == &other) { + return true; + } + + if (shape() != other.shape()) { + return false; + } + + switch (shape()) { + case SCALAR: + return type()->Equals(other.type()); + case ARRAY: + 
return type()->Equals(other.type()); + case TABLE: + return schema()->Equals(other.schema()); + default: + break; + } + + return false; +} + +Result ExprType::WithType(const std::shared_ptr& data_type) const { + switch (shape()) { + case SCALAR: + return ExprType::Scalar(data_type); + case ARRAY: + return ExprType::Array(data_type); + case TABLE: + return Status::Invalid("Cannot cast a TableType with a DataType"); + } + + return Status::UnknownError("unreachable"); +} + +Result ExprType::WithSchema(const std::shared_ptr& schema) const { + switch (shape()) { + case SCALAR: + return Status::Invalid("Cannot cast a ScalarType with a schema"); + case ARRAY: + return Status::Invalid("Cannot cast an ArrayType with a schema"); + case TABLE: + return ExprType::Table(schema); + } + + return Status::UnknownError("unreachable"); +} + +Result ExprType::Broadcast(const ExprType& lhs, const ExprType& rhs) { + if (lhs.IsTable() || rhs.IsTable()) { + return Status::Invalid("Broadcast operands must not be tables"); + } + + if (!lhs.type()->Equals(rhs.type())) { + return Status::Invalid("Broadcast operands must be of same type"); + } + + if (lhs.IsArray()) { + return lhs; + } + + if (rhs.IsArray()) { + return rhs; + } + + return lhs; +} + +#define ERROR_IF_TYPE(cond, ErrorType, ...) \ + do { \ + if (ARROW_PREDICT_FALSE(cond)) { \ + return Status::ErrorType(__VA_ARGS__); \ + } \ + } while (false) + +#define ERROR_IF(cond, ...) 
ERROR_IF_TYPE(cond, Invalid, __VA_ARGS__) + +// +// Expr +// + +std::string Expr::kind_name() const { + switch (kind_) { + case ExprKind::SCALAR_LITERAL: + return "scalar"; + case ExprKind::FIELD_REFERENCE: + return "field_ref"; + case ExprKind::COMPARE_OP: + return "compare_op"; + case ExprKind::AGGREGATE_FN_OP: + return "aggregate_fn_op"; + case ExprKind::EMPTY_REL: + return "empty_rel"; + case ExprKind::SCAN_REL: + return "scan_rel"; + case ExprKind::PROJECTION_REL: + return "projection_rel"; + case ExprKind::FILTER_REL: + return "filter_rel"; + } + + return "unknown expr"; +} + +struct ExprEqualityVisitor { + bool operator()(const ScalarExpr& rhs) const { + auto lhs_scalar = internal::checked_cast(lhs); + return lhs_scalar.scalar()->Equals(*rhs.scalar()); + } + + bool operator()(const FieldRefExpr& rhs) const { + auto lhs_field = internal::checked_cast(lhs); + return lhs_field.index() == rhs.index() && + lhs_field.operand()->Equals(*rhs.operand()); + } + + template + enable_if_compare_expr operator()(const E& rhs) const { + auto lhs_cmp = internal::checked_cast(lhs); + return (lhs_cmp.left_operand()->Equals(rhs.left_operand()) && + lhs_cmp.right_operand()->Equals(rhs.right_operand())) || + (lhs_cmp.left_operand()->Equals(rhs.right_operand()) && + lhs_cmp.left_operand()->Equals(rhs.right_operand())); + } + + bool operator()(const EmptyRelExpr& rhs) const { + auto lhs_empty = internal::checked_cast(lhs); + return lhs_empty.schema()->Equals(rhs.schema()); + } + + bool operator()(const ScanRelExpr& rhs) const { + auto lhs_scan = internal::checked_cast(lhs); + // Performs a pointer equality on Table/Dataset + return lhs_scan.input() == rhs.input(); + } + + bool operator()(const ProjectionRelExpr& rhs) const { + auto lhs_proj = internal::checked_cast(lhs); + + const auto& lhs_exprs = lhs_proj.expressions(); + const auto& rhs_exprs = rhs.expressions(); + if (lhs_exprs.size() != rhs_exprs.size()) { + return false; + } + + for (size_t i = 0; i < lhs_exprs.size(); i++) { 
+ if (!lhs_exprs[i]->Equals(rhs_exprs[i])) { + return false; + } + } + + return true; + } + + bool operator()(const Expr&) const { return false; } + + static bool Visit(const Expr& lhs, const Expr& rhs) { + return VisitExpr(rhs, ExprEqualityVisitor{lhs}); + } + + const Expr& lhs; +}; + +bool Expr::Equals(const Expr& other) const { + if (this == &other) { + return true; + } + + if (kind() != other.kind() || type() != other.type()) { + return false; + } + + return ExprEqualityVisitor::Visit(*this, other); +} + +std::string Expr::ToString() const { return ""; } + +// +// ScalarExpr +// + +ScalarExpr::ScalarExpr(std::shared_ptr scalar) + : Expr(SCALAR_LITERAL, ExprType::Scalar(scalar->type)), scalar_(std::move(scalar)) {} + +Result> ScalarExpr::Make(std::shared_ptr scalar) { + ERROR_IF(scalar == nullptr, "ScalarExpr's scalar must be non-null"); + return std::shared_ptr(new ScalarExpr(std::move(scalar))); +} + +// +// FieldRefExpr +// + +FieldRefExpr::FieldRefExpr(std::shared_ptr input, int index) + : UnaryOpMixin(std::move(input)), + Expr(FIELD_REFERENCE, + ExprType::Array(operand()->type().schema()->field(index)->type())), + index_(index) {} + +Result> FieldRefExpr::Make(std::shared_ptr input, + int index) { + ERROR_IF(input == nullptr, "FieldRefExpr's input must be non-null"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.IsTable(), "FieldRefExpr's input must have a table shape, got '", + ShapeToString(expr_type.shape()), "'"); + + auto schema = expr_type.schema(); + ERROR_IF_TYPE(index < 0 || index >= schema->num_fields(), KeyError, + "FieldRefExpr's index is out of bound, '", index, "' not in range [0, ", + schema->num_fields(), ")"); + + return std::shared_ptr(new FieldRefExpr(std::move(input), index)); +} + +Result> FieldRefExpr::Make(std::shared_ptr input, + std::string field_name) { + ERROR_IF(input == nullptr, "FieldRefExpr's input must be non-null"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.IsTable(), "FieldRefExpr's input must 
have a table shape, got '", + ShapeToString(expr_type.shape()), "'"); + + auto schema = expr_type.schema(); + auto field = schema->GetFieldByName(field_name); + ERROR_IF_TYPE(field == nullptr, KeyError, + "FieldRefExpr's can't reference with field name '", field_name, "'"); + + auto index = schema->GetFieldIndex(field_name); + ERROR_IF(index == -1, "FieldRefExpr's index by name is invalid."); + + return std::shared_ptr(new FieldRefExpr(std::move(input), index)); +} + +// +// CountExpr +// + +CountExpr::CountExpr(std::shared_ptr input) + : UnaryOpMixin(std::move(input)), + AggregateFnExpr(ExprType::Scalar(int64()), AggregateFnKind::COUNT) {} + +Result> CountExpr::Make(std::shared_ptr input) { + ERROR_IF(input == nullptr, "CountExpr's input must be non-null"); + return std::shared_ptr(new CountExpr(std::move(input))); +} + +// +// SumExpr +// + +SumExpr::SumExpr(std::shared_ptr input) + : UnaryOpMixin(std::move(input)), + AggregateFnExpr( + ExprType::Scalar(arrow::compute::GetAccumulatorType(operand()->type().type())), + AggregateFnKind::SUM) {} + +Result> SumExpr::Make(std::shared_ptr input) { + ERROR_IF(input == nullptr, "SumExpr's input must be non-null"); + + auto expr_type = input->type(); + ERROR_IF(!expr_type.HasType(), "SumExpr's input must be a Scalar or an Array"); + + auto type = expr_type.type(); + ERROR_IF(!is_numeric(type->id()), "SumExpr's require an input with numeric type"); + + return std::shared_ptr(new SumExpr(std::move(input))); +} + +// +// EmptyRelExpr +// + +EmptyRelExpr::EmptyRelExpr(std::shared_ptr schema) + : RelExpr(ExprKind::EMPTY_REL, std::move(schema)) {} + +Result> EmptyRelExpr::Make(std::shared_ptr schema) { + ERROR_IF(schema == nullptr, "EmptyRelExpr schema must be non-null"); + return std::shared_ptr(new EmptyRelExpr(std::move(schema))); +} + +// +// ScanRelExpr +// + +ScanRelExpr::ScanRelExpr(Catalog::Entry input) + : RelExpr(ExprKind::SCAN_REL, input.schema()), input_(std::move(input)) {} + +Result> 
ScanRelExpr::Make(Catalog::Entry input) { + return std::shared_ptr(new ScanRelExpr(std::move(input))); +} + +// +// ProjectionRelExpr +// + +ProjectionRelExpr::ProjectionRelExpr(std::shared_ptr input, + std::shared_ptr schema, + std::vector> expressions) + : UnaryOpMixin(std::move(input)), + RelExpr(ExprKind::PROJECTION_REL, std::move(schema)), + expressions_(std::move(expressions)) {} + +Result> ProjectionRelExpr::Make( + std::shared_ptr input, std::vector> expressions) { + ERROR_IF(input == nullptr, "ProjectionRelExpr's input must be non-null."); + ERROR_IF(expressions.empty(), "Must project at least one expression."); + + auto n_fields = expressions.size(); + std::vector> fields; + + for (size_t i = 0; i < n_fields; i++) { + const auto& expr = expressions[i]; + const auto& expr_type = expr->type(); + ERROR_IF(!expr_type.HasType(), "Expression at position ", i, + " should not be have a table shape"); + // TODO(fsaintjacques): better name handling. Callers should be able to + // pass a vector of names. 
+ fields.push_back(field("expr", expr_type.type())); + } + + return std::shared_ptr(new ProjectionRelExpr( + std::move(input), arrow::schema(std::move(fields)), std::move(expressions))); +} + +// +// FilterRelExpr +// + +Result> FilterRelExpr::Make( + std::shared_ptr input, std::shared_ptr predicate) { + ERROR_IF(input == nullptr, "FilterRelExpr's input must be non-null."); + ERROR_IF(!input->type().IsTable(), "FilterRelExpr's input must be a table."); + ERROR_IF(predicate == nullptr, "FilterRelExpr's predicate must be non-null."); + ERROR_IF(!predicate->type().IsPredicate(), + "FilterRelExpr's predicate must be a predicate"); + + return std::shared_ptr( + new FilterRelExpr(std::move(input), std::move(predicate))); +} + +FilterRelExpr::FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate) + : UnaryOpMixin(std::move(input)), + RelExpr(ExprKind::FILTER_REL, operand()->type().schema()), + predicate_(std::move(predicate)) {} + +#undef ERROR_IF +#undef ERROR_IF_TYPE + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/expression.h b/cpp/src/arrow/engine/expression.h new file mode 100644 index 00000000000..a007c70d98f --- /dev/null +++ b/cpp/src/arrow/engine/expression.h @@ -0,0 +1,575 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/engine/catalog.h" +#include "arrow/engine/type_fwd.h" +#include "arrow/engine/type_traits.h" +#include "arrow/engine/visibility.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/compare.h" +#include "arrow/util/macros.h" + +namespace arrow { +namespace engine { + +/// ExprType is a class representing the type of an Expression. The type is +/// composed of a shape and a DataType or a Schema depending on the shape. +/// +/// ExprType is mainly used to validate arguments for operator expressions, e.g. +/// relational operator expressions expect inputs of Table shape. +/// +/// The sum-type representation would be: +/// +/// enum ExprType { +/// ScalarType(DataType), +/// ArrayType(DataType), +/// TableType(Schema), +/// } +class ARROW_EN_EXPORT ExprType : public util::EqualityComparable { + public: + enum Shape : uint8_t { + // The expression yields a Scalar, e.g. "1". + SCALAR, + // The expression yields an Array, e.g. "[1, 2, 3]". + ARRAY, + // The expression yields a Table, e.g. "{'a': [1, 2], 'b': [true, false]}" + TABLE, + }; + + /// Construct a Scalar type. + static ExprType Scalar(std::shared_ptr type); + /// Construct an Array type. + static ExprType Array(std::shared_ptr type); + /// Construct a Table type. + static ExprType Table(std::shared_ptr schema); + static ExprType Table(std::vector> fields); + + /// \brief Shape of the expression. + Shape shape() const { return shape_; } + + /// \brief DataType of the expression if a scalar or an array. + /// WARNING: You must ensure the proper shape before calling this accessor. + const std::shared_ptr& type() const { return data_type_; } + /// \brief Schema of the expression if of table shape. + /// WARNING: You must ensure the proper shape before calling this accessor. 
+ const std::shared_ptr& schema() const { return schema_; } + + /// \brief Indicate if the type is a Scalar. + bool IsScalar() const { return shape_ == SCALAR; } + /// \brief Indicate if the type is an Array. + bool IsArray() const { return shape_ == ARRAY; } + /// \brief Indicate if the type is a Table. + bool IsTable() const { return shape_ == TABLE; } + + bool HasType() const { return IsScalar() || IsArray(); } + bool HasSchema() const { return IsTable(); } + + bool IsTypedLike(Type::type type_id) const { + return HasType() && data_type_->id() == type_id; + } + + /// \brief Static version of IsTypedLike + template + bool IsTypedLike() const { + return HasType() && data_type_->id() == TYPE_ID; + } + + /// \brief Indicate if the type is a predicate, i.e. a boolean scalar. + bool IsPredicate() const { return IsTypedLike(); } + + /// \brief Cast the inner DataType/Schema while preserving the shape. + Result WithType(const std::shared_ptr& data_type) const; + Result WithSchema(const std::shared_ptr& schema) const; + + /// \brief Expand the smallest shape to the bigger one if possible. + /// + /// \param[in] lhs first type to broadcast + /// \param[in] rhs second type to broadcast + /// \return broadcasted type or an error why it can't be broadcasted. + /// + /// Broadcasting promotes the shape of the smallest type to the bigger one if + /// they share the same DataType. 
In functional pattern matching it would look + /// like: + /// + /// ``` + /// Broadcast(rhs, lhs) = match(lhs, rhs) { + /// case: ScalarType(t1), ScalarType(t2) if t1 == t2 => ScalarType(t) + /// case: ScalarType(t1), ArrayType(t2) if t1 == t2 => ArrayType(t) + /// case: ArrayType(t1), ScalarType(t2) if t1 == t2 => ArrayType(t) + /// case: ArrayType(t1), ArrayType(t2) if t1 == t2 => ArrayType(t) + /// case: _ => Error("Types not compatible for broadcasting") + /// } + /// ``` + static Result Broadcast(const ExprType& lhs, const ExprType& rhs); + + bool Equals(const ExprType& type) const; + + std::string ToString() const; + + ExprType(const ExprType& copy); + ExprType(ExprType&& copy); + ~ExprType(); + + private: + /// Table constructor + ExprType(std::shared_ptr schema, Shape shape); + /// Scalar or Array constructor + ExprType(std::shared_ptr type, Shape shape); + + union { + /// Zero initialize the pointer or Copy/Assign constructors will fail. + std::shared_ptr data_type_{}; + std::shared_ptr schema_; + }; + Shape shape_; +}; + +/// Represents an expression tree +class ARROW_EN_EXPORT Expr : public util::EqualityComparable { + public: + /// \brief Return the kind of the expression. + ExprKind kind() const { return kind_; } + /// \brief Return a string representation of the kind. + std::string kind_name() const; + + /// \brief Return the type and shape of the resulting expression. + const ExprType& type() const { return type_; } + + /// \brief Indicate if the expressions are equal. + bool Equals(const Expr& other) const; + using util::EqualityComparable::Equals; + + /// \brief Return a string representing the expression + std::string ToString() const; + + virtual ~Expr() = default; + + protected: + explicit Expr(ExprKind kind, ExprType type) : type_(std::move(type)), kind_(kind) {} + + ExprType type_; + ExprKind kind_; +}; + +/// +/// Operator expressions mixin. 
+/// + +class ARROW_EN_EXPORT UnaryOpMixin { + public: + const std::shared_ptr& operand() const { return operand_; } + + protected: + explicit UnaryOpMixin(std::shared_ptr operand) : operand_(std::move(operand)) {} + + std::shared_ptr operand_; +}; + +class ARROW_EN_EXPORT BinaryOpMixin { + public: + const std::shared_ptr& left_operand() const { return left_operand_; } + const std::shared_ptr& right_operand() const { return right_operand_; } + + protected: + BinaryOpMixin(std::shared_ptr left, std::shared_ptr right) + : left_operand_(std::move(left)), right_operand_(std::move(right)) {} + + std::shared_ptr left_operand_; + std::shared_ptr right_operand_; +}; + +class ARROW_EN_EXPORT MultiAryOpMixin { + public: + const std::vector>& operands() const { return operands_; } + + protected: + explicit MultiAryOpMixin(std::vector> operands) + : operands_(std::move(operands)) {} + + std::vector> operands_; +}; + +/// +/// Value Expressions +/// + +/// An unnamed scalar literal expression. +class ARROW_EN_EXPORT ScalarExpr : public Expr { + public: + static Result> Make(std::shared_ptr scalar); + + const std::shared_ptr& scalar() const { return scalar_; } + + private: + explicit ScalarExpr(std::shared_ptr scalar); + + std::shared_ptr scalar_; +}; + +/// References a column in a table/dataset +class ARROW_EN_EXPORT FieldRefExpr : public UnaryOpMixin, public Expr { + public: + static Result> Make(std::shared_ptr input, + int index); + static Result> Make(std::shared_ptr input, + std::string field_name); + + int index() const { return index_; } + + private: + FieldRefExpr(std::shared_ptr input, int index); + + int index_; +}; + +/// +/// Comparison expressions +/// + +class ARROW_EN_EXPORT CompareOpExpr : public BinaryOpMixin, public Expr { + public: + CompareKind compare_kind() const { return compare_kind_; } + + /// This inner-class is required because `using` statements can't use derived + /// methods. 
+ template + struct MakeMixin { + static Result> Make(std::shared_ptr left, + std::shared_ptr right) { + if (left == NULLPTR || right == NULLPTR) { + return Status::Invalid("Compare operands must be non-nulls"); + } + + // Broadcast the comparison to the biggest shape. + ARROW_ASSIGN_OR_RAISE(auto broadcast, + ExprType::Broadcast(left->type(), right->type())); + // And change this shape's type to boolean. + ARROW_ASSIGN_OR_RAISE(auto type, broadcast.WithType(boolean())); + + return std::shared_ptr(new Derived(std::move(type), + expr_traits::compare_kind_id, + std::move(left), std::move(right))); + } + }; + + protected: + CompareOpExpr(ExprType type, CompareKind op, std::shared_ptr left, + std::shared_ptr right) + : BinaryOpMixin(std::move(left), std::move(right)), + Expr(COMPARE_OP, std::move(type)), + compare_kind_(op) {} + + CompareKind compare_kind_; +}; + +template +class BaseCompareExpr : public CompareOpExpr, + protected CompareOpExpr::MakeMixin { + public: + using CompareOpExpr::MakeMixin::Make; + + protected: + using CompareOpExpr::CompareOpExpr; +}; + +class ARROW_EN_EXPORT EqualExpr : public BaseCompareExpr { + protected: + using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; +}; + +class ARROW_EN_EXPORT NotEqualExpr : public BaseCompareExpr { + protected: + using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; +}; + +class ARROW_EN_EXPORT GreaterThanExpr : public BaseCompareExpr { + protected: + using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; +}; + +class ARROW_EN_EXPORT GreaterThanEqualExpr + : public BaseCompareExpr { + protected: + using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; +}; + +class ARROW_EN_EXPORT LessThanExpr : public BaseCompareExpr { + protected: + using BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; +}; + +class ARROW_EN_EXPORT LessThanEqualExpr : public BaseCompareExpr { + protected: + using 
BaseCompareExpr::BaseCompareExpr; + friend CompareOpExpr::MakeMixin; +}; + +/// +/// Aggregate Functions +/// + +/// \brief Aggregate function operators collapse arrays and scalars to scalar. +class ARROW_EN_EXPORT AggregateFnExpr : public Expr { + public: + AggregateFnKind aggregate_kind() const { return aggregate_kind_; } + + protected: + AggregateFnExpr(ExprType type, AggregateFnKind kind) + : Expr(AGGREGATE_FN_OP, std::move(type)), aggregate_kind_(kind) {} + + AggregateFnKind aggregate_kind_; +}; + +/// \brief Count the number of values in the input expression. +class ARROW_EN_EXPORT CountExpr : public UnaryOpMixin, public AggregateFnExpr { + public: + static Result> Make(std::shared_ptr input); + + protected: + explicit CountExpr(std::shared_ptr input); +}; + +/// \brief Sum the input values. +class ARROW_EN_EXPORT SumExpr : public UnaryOpMixin, public AggregateFnExpr { + public: + static Result> Make(std::shared_ptr input); + + protected: + explicit SumExpr(std::shared_ptr input); +}; + +/// +/// Relational Expressions +/// + +/// \brief Relational Expressions that acts on tables. +class ARROW_EN_EXPORT RelExpr : public Expr { + public: + const std::shared_ptr& schema() const { return schema_; } + + protected: + explicit RelExpr(ExprKind kind, std::shared_ptr schema) + : Expr(kind, ExprType::Table(schema)), schema_(std::move(schema)) {} + + std::shared_ptr schema_; +}; + +/// \brief An empty relation that returns/contains no rows. +/// +/// An EmptyRelExpr is usually not found in user constructed logical plan but +/// can appear due to optimization passes, e.g. replacing a FilterRelExpr with +/// an always false predicate. It is also subsequently used in constant +/// propagation-like optimizations, e.g Filter(EmptyRel) => EmptyRel, or +/// InnerJoin(_, EmptyRel) => EmptyRel. 
+/// +/// \input schema, the schema of the empty relation +/// \ouput relation with no rows of the given input schema +class ARROW_EN_EXPORT EmptyRelExpr : public RelExpr { + public: + static Result> Make(std::shared_ptr schema); + + protected: + explicit EmptyRelExpr(std::shared_ptr schema); +}; + +/// \brief Materialize a relation from a dataset. +/// +/// The ScanRelExpr are found in the leaves of the Expr tree. A Scan materialize +/// the relation from a datasets. In essence, it is a relational operator that +/// has no relation input (except some auxiliary information like a catalog +/// entry), and output a relation. +/// +/// \input table, a catalog entry pointing to a dataset +/// \ouput relation from the materialized dataset +/// +/// ``` +/// SELECT * FROM table; +/// ``` +class ARROW_EN_EXPORT ScanRelExpr : public RelExpr { + public: + static Result> Make(Catalog::Entry input); + + const Catalog::Entry& input() const { return input_; } + + private: + explicit ScanRelExpr(Catalog::Entry input); + + Catalog::Entry input_; +}; + +/// \brief Project columns based on expressions. +/// +/// A projection creates a relation with new columns based on expressions of +/// the input's columns. It could be a simple permutation or selection of +/// column via FieldRefExpr or more complex expressions like the sum of two +/// columns. The projection operator will usually change the output schema of +/// the input relation due to the expressions without changing the number of +/// rows. 
+/// +/// \input relation, the input relation to compute the expressions from +/// \input expressions, the expressions to compute +/// \output relation where the columns are the expressions computed +/// +/// ``` +/// SELECT a, b, a + b, 1, mean(a) > b FROM relation; +/// ``` +class ARROW_EN_EXPORT ProjectionRelExpr : public UnaryOpMixin, public RelExpr { + public: + static Result> Make( + std::shared_ptr input, std::vector> expressions); + + const std::vector> expressions() const { return expressions_; } + + private: + ProjectionRelExpr(std::shared_ptr input, std::shared_ptr schema, + std::vector> expressions); + + std::vector> expressions_; +}; + +/// \brief Filter the rows of a relation according to a predicate. +/// +/// A filter removes rows that don't match a predicate or a mask column. +/// +/// \input relation, the input relation to filter the rows from +/// \input predicate, a predicate to evaluate for each filter +/// \output relation where the rows are filtered according to the predicate +/// +/// ``` +/// SELECT * FROM relation WHERE predicate +/// ``` +class ARROW_EN_EXPORT FilterRelExpr : public UnaryOpMixin, public RelExpr { + public: + static Result> Make(std::shared_ptr input, + std::shared_ptr predicate); + + const std::shared_ptr& predicate() const { return predicate_; } + + private: + FilterRelExpr(std::shared_ptr input, std::shared_ptr predicate); + + std::shared_ptr predicate_; +}; + +template +auto VisitExpr(const Expr& expr, Visitor&& visitor) -> decltype(visitor(expr)) { + switch (expr.kind()) { + case ExprKind::SCALAR_LITERAL: + return visitor(internal::checked_cast(expr)); + case ExprKind::FIELD_REFERENCE: + return visitor(internal::checked_cast(expr)); + + case ExprKind::COMPARE_OP: { + const auto& cmp_expr = static_cast(expr); + switch (cmp_expr.compare_kind()) { + case CompareKind::EQUAL: + return visitor(internal::checked_cast(expr)); + case CompareKind::NOT_EQUAL: + return visitor(internal::checked_cast(expr)); + case 
CompareKind::GREATER_THAN: + return visitor(internal::checked_cast(expr)); + case CompareKind::GREATER_THAN_EQUAL: + return visitor(internal::checked_cast(expr)); + case CompareKind::LESS_THAN: + return visitor(internal::checked_cast(expr)); + case CompareKind::LESS_THAN_EQUAL: + return visitor(internal::checked_cast(expr)); + } + + ARROW_UNREACHABLE; + } + + case ExprKind::AGGREGATE_FN_OP: { + const auto& agg_expr = static_cast(expr); + switch (agg_expr.aggregate_kind()) { + case AggregateFnKind::COUNT: + return visitor(internal::checked_cast(expr)); + case AggregateFnKind::SUM: + return visitor(internal::checked_cast(expr)); + } + + ARROW_UNREACHABLE; + } + + case ExprKind::EMPTY_REL: + return visitor(internal::checked_cast(expr)); + case ExprKind::SCAN_REL: + return visitor(internal::checked_cast(expr)); + case ExprKind::PROJECTION_REL: + return visitor(internal::checked_cast(expr)); + case ExprKind::FILTER_REL: + return visitor(internal::checked_cast(expr)); + } + + ARROW_UNREACHABLE; +} + +/// +/// RTTI utilities +/// + +/// \defgroup isa-expr Family of functions to introspect if an expression of a +/// given expression class. 
+/// @{ + +template +enable_if_simple_expr IsA(const Expr& expr) { + return expr.kind() == expr_traits::kind_id; +} + +template +enable_if_compare_expr IsA(const Expr& expr) { + if (expr.kind() != ExprKind::COMPARE_OP) { + return false; + } + const auto& cmp = internal::checked_cast(expr); + return cmp.compare_kind() == expr_traits::compare_kind_id; +} + +template +enable_if_aggregate_fn_expr IsA(const Expr& expr) { + if (expr.kind() != ExprKind::AGGREGATE_FN_OP) { + return false; + } + const auto& agg = internal::checked_cast(expr); + return agg.aggregate_kind() == expr_traits::aggregate_kind_id; +} + +template +bool IsA(const std::shared_ptr& expr) { + if (!expr) return false; + return IsA(*expr); +} + +/// @} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/expression_test.cc b/cpp/src/arrow/engine/expression_test.cc new file mode 100644 index 00000000000..037327e12f6 --- /dev/null +++ b/cpp/src/arrow/engine/expression_test.cc @@ -0,0 +1,322 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/engine/expression.h" +#include "arrow/scalar.h" +#include "arrow/testing/gmock.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/type.h" + +using testing::HasSubstr; +using testing::Not; +using testing::Pointee; + +namespace arrow { +namespace engine { + +class ExprTypeTest : public testing::Test {}; + +TEST_F(ExprTypeTest, Basic) { + auto i32 = int32(); + auto s = schema({field("i32", i32)}); + + auto scalar = ExprType::Scalar(i32); + EXPECT_EQ(scalar.shape(), ExprType::Shape::SCALAR); + EXPECT_TRUE(scalar.type()->Equals(i32)); + EXPECT_TRUE(scalar.IsScalar()); + EXPECT_FALSE(scalar.IsArray()); + EXPECT_FALSE(scalar.IsTable()); + + auto array = ExprType::Array(i32); + EXPECT_EQ(array.shape(), ExprType::Shape::ARRAY); + EXPECT_TRUE(array.type()->Equals(i32)); + EXPECT_FALSE(array.IsScalar()); + EXPECT_TRUE(array.IsArray()); + EXPECT_FALSE(array.IsTable()); + + auto table = ExprType::Table(s); + EXPECT_EQ(table.shape(), ExprType::Shape::TABLE); + EXPECT_TRUE(table.schema()->Equals(s)); + EXPECT_FALSE(table.IsScalar()); + EXPECT_FALSE(table.IsArray()); + EXPECT_TRUE(table.IsTable()); +} + +TEST_F(ExprTypeTest, IsPredicate) { + auto bool_scalar = ExprType::Scalar(boolean()); + EXPECT_TRUE(bool_scalar.IsPredicate()); + + auto bool_array = ExprType::Array(boolean()); + EXPECT_TRUE(bool_array.IsPredicate()); + + auto bool_table = ExprType::Table(schema({field("b", boolean())})); + EXPECT_FALSE(bool_table.IsPredicate()); + + auto i32_scalar = ExprType::Scalar(int32()); + EXPECT_FALSE(i32_scalar.IsPredicate()); +} + +TEST_F(ExprTypeTest, Broadcast) { + auto bool_scalar = ExprType::Scalar(boolean()); + auto bool_array = ExprType::Array(boolean()); + auto bool_table = ExprType::Table(schema({field("b", boolean())})); + auto i32_scalar = ExprType::Scalar(int32()); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must be of same type"), + ExprType::Broadcast(bool_scalar, i32_scalar)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, 
HasSubstr("operands must not be tables"), + ExprType::Broadcast(bool_scalar, bool_table)); + + EXPECT_THAT(ExprType::Broadcast(bool_scalar, bool_scalar), OkAndEq(bool_scalar)); + EXPECT_THAT(ExprType::Broadcast(bool_scalar, bool_array), OkAndEq(bool_array)); + EXPECT_THAT(ExprType::Broadcast(bool_array, bool_scalar), OkAndEq(bool_array)); + EXPECT_THAT(ExprType::Broadcast(bool_array, bool_array), OkAndEq(bool_array)); +} + +TEST_F(ExprTypeTest, WithTypeOrSchema) { + auto bool_scalar = ExprType::Scalar(boolean()); + auto bool_array = ExprType::Array(boolean()); + auto bool_table = ExprType::Table(schema({field("b", boolean())})); + + auto i32 = int32(); + auto other = schema({field("a", i32)}); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast a ScalarType with"), + bool_scalar.WithSchema(other)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast an ArrayType with"), + bool_array.WithSchema(other)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("Cannot cast a TableType with"), + bool_table.WithType(i32)); + + EXPECT_EQ(bool_scalar.WithType(i32), ExprType::Scalar(i32)); + EXPECT_EQ(bool_array.WithType(i32), ExprType::Array(i32)); + EXPECT_EQ(bool_table.WithSchema(other), ExprType::Table(other)); +} + +class ExprTest : public testing::Test {}; + +TEST_F(ExprTest, ScalarExpr) { + ASSERT_RAISES(Invalid, ScalarExpr::Make(nullptr)); + + auto i32 = int32(); + ASSERT_OK_AND_ASSIGN(auto value, MakeScalar(i32, 10)); + ASSERT_OK_AND_ASSIGN(auto expr, ScalarExpr::Make(value)); + EXPECT_EQ(expr->kind(), ExprKind::SCALAR_LITERAL); + EXPECT_EQ(expr->type(), ExprType::Scalar(i32)); + EXPECT_EQ(*expr->scalar(), *value); +} + +TEST_F(ExprTest, FieldRefExpr) { + auto i32 = int32(); + auto f_i32 = field("i32", i32); + auto schema = arrow::schema({f_i32}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + ASSERT_RAISES(Invalid, FieldRefExpr::Make(nullptr, 0)); + ASSERT_RAISES(KeyError, FieldRefExpr::Make(input, -1)); + 
ASSERT_RAISES(KeyError, FieldRefExpr::Make(input, 1)); + ASSERT_RAISES(KeyError, FieldRefExpr::Make(input, "not_present")); + + ASSERT_OK_AND_ASSIGN(auto expr, FieldRefExpr::Make(input, 0)); + EXPECT_EQ(expr->kind(), ExprKind::FIELD_REFERENCE); + EXPECT_EQ(expr->type(), ExprType::Array(i32)); + EXPECT_THAT(expr->index(), 0); + + ASSERT_OK_AND_ASSIGN(expr, FieldRefExpr::Make(input, "i32")); + EXPECT_EQ(expr->kind(), ExprKind::FIELD_REFERENCE); + EXPECT_EQ(expr->type(), ExprType::Array(i32)); + EXPECT_THAT(expr->index(), 0); +} + +template +class CompareExprTest : public ExprTest { + public: + ExprKind kind() { return expr_traits::kind_id; } + CompareKind compare_kind() { return expr_traits::compare_kind_id; } + + Result> Make(std::shared_ptr left, + std::shared_ptr right) { + return CmpClass::Make(std::move(left), std::move(right)); + } +}; + +using CompareExprs = + ::testing::Types; + +TYPED_TEST_CASE(CompareExprTest, CompareExprs); +TYPED_TEST(CompareExprTest, BasicCompareExpr) { + auto i32 = int32(); + auto f_i32 = field("i32", i32); + auto schema = arrow::schema({f_i32}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + ASSERT_OK_AND_ASSIGN(auto f_expr, FieldRefExpr::Make(input, "i32")); + ASSERT_OK_AND_ASSIGN(auto s_i32, MakeScalar(i32, 42)); + ASSERT_OK_AND_ASSIGN(auto s_expr, ScalarExpr::Make(s_i32)); + + // Required fields + ASSERT_RAISES(Invalid, this->Make(nullptr, nullptr)); + ASSERT_RAISES(Invalid, this->Make(s_expr, nullptr)); + ASSERT_RAISES(Invalid, this->Make(nullptr, f_expr)); + + // Not type compatible + ASSERT_OK_AND_ASSIGN(auto s_i64, MakeScalar(int64(), static_cast(42))); + ASSERT_OK_AND_ASSIGN(auto s_expr_i64, ScalarExpr::Make(s_i64)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("operands must be of same type"), + this->Make(s_expr_i64, f_expr)); + + ASSERT_OK_AND_ASSIGN(auto expr, this->Make(f_expr, s_expr)); + EXPECT_EQ(expr->kind(), this->kind()); + EXPECT_EQ(expr->compare_kind(), this->compare_kind()); + // 
Ensure type is broadcasted + EXPECT_EQ(expr->type(), ExprType::Array(boolean())); + EXPECT_TRUE(expr->type().IsPredicate()); + EXPECT_THAT(expr, PtrEquals(expr)); + EXPECT_THAT(expr->left_operand(), PtrEquals(f_expr)); + EXPECT_THAT(expr->right_operand(), PtrEquals(s_expr)); + + ASSERT_OK_AND_ASSIGN(auto other, this->Make(f_expr, s_expr)); + EXPECT_THAT(expr, PtrEquals(other)); + // Compare expressions are expected to be equal regardless of operand order. + // TODO(fsaintjacques): what about floating point types? + ASSERT_OK_AND_ASSIGN(auto swapped, this->Make(s_expr, f_expr)); + EXPECT_THAT(expr, PtrEquals(swapped)); +} + +TEST_F(ExprTest, CountExpr) { + ASSERT_RAISES(Invalid, CountExpr::Make(nullptr)); + + // Counting a scalar is permitted. + ASSERT_OK_AND_ASSIGN(auto i32_lit, ScalarExpr::Make(MakeScalar(42))); + EXPECT_THAT(CountExpr::Make(i32_lit), Ok()); + + // Counting a string scalar is permitted. + ASSERT_OK_AND_ASSIGN(auto str_lit, ScalarExpr::Make(MakeScalar("hi"))); + EXPECT_THAT(CountExpr::Make(str_lit), Ok()); + + auto schema = arrow::schema({field("i32", int32()), field("str", utf8())}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + // Counting an int column should be supported. + ASSERT_OK_AND_ASSIGN(auto i32_column, FieldRefExpr::Make(input, 0)); + EXPECT_THAT(CountExpr::Make(i32_column), Ok()); + + // Counting a string column should be supported. + ASSERT_OK_AND_ASSIGN(auto str_column, FieldRefExpr::Make(input, 1)); + EXPECT_THAT(CountExpr::Make(str_column), Ok()); + + // Counting a table should be supported. + EXPECT_THAT(CountExpr::Make(input), Ok()); +} + +TEST_F(ExprTest, SumExpr) { + ASSERT_RAISES(Invalid, SumExpr::Make(nullptr)); + + // Summing a scalar is permitted. + ASSERT_OK_AND_ASSIGN(auto i32_lit, ScalarExpr::Make(MakeScalar(42))); + EXPECT_THAT(SumExpr::Make(i32_lit), Ok()); + + // Summing a string is not permitted.
+ ASSERT_OK_AND_ASSIGN(auto str_lit, ScalarExpr::Make(MakeScalar("hi"))); + ASSERT_RAISES(Invalid, SumExpr::Make(str_lit)); + + auto schema = arrow::schema( + {field("i32", int32()), field("str", utf8()), field("list_i32", list(int32()))}); + ASSERT_OK_AND_ASSIGN(auto input, EmptyRelExpr::Make(schema)); + + // Summing an integer column should be supported. + ASSERT_OK_AND_ASSIGN(auto i32_column, FieldRefExpr::Make(input, 0)); + EXPECT_THAT(SumExpr::Make(i32_column), Ok()); + + // Summing a string column should not be supported. + ASSERT_OK_AND_ASSIGN(auto str_column, FieldRefExpr::Make(input, 1)); + ASSERT_RAISES(Invalid, SumExpr::Make(str_column)); + + // Summing a list column should not be supported (yet). + ASSERT_OK_AND_ASSIGN(auto list_i32_column, FieldRefExpr::Make(input, 2)); + ASSERT_RAISES(Invalid, SumExpr::Make(list_i32_column)); + + // Summing a table should not be supported + ASSERT_RAISES(Invalid, SumExpr::Make(input)); +} + +class RelExprTest : public ExprTest { + protected: + void SetUp() override { + CatalogBuilder builder; + ASSERT_OK(builder.Add(table_1, MockTable(schema_1))); + ASSERT_OK_AND_ASSIGN(catalog, builder.Finish()); + } + + std::string table_1 = "table_1"; + std::shared_ptr schema_1 = schema({field("i32", int32())}); + + std::shared_ptr catalog; +}; + +TEST_F(RelExprTest, EmptyRelExpr) { + ASSERT_RAISES(Invalid, EmptyRelExpr::Make(nullptr)); + + ASSERT_OK_AND_ASSIGN(auto empty, EmptyRelExpr::Make(schema_1)); + EXPECT_THAT(empty->type(), ExprType::Table(schema_1)); + EXPECT_THAT(empty->schema(), PtrEquals(schema_1)); + EXPECT_THAT(empty, PtrEquals(empty)); + + ASSERT_OK_AND_ASSIGN(auto other, EmptyRelExpr::Make(schema_1)); + EXPECT_THAT(other, PtrEquals(empty)); +} + +TEST_F(RelExprTest, ScanRelExpr) { + ASSERT_OK_AND_ASSIGN(auto table, catalog->Get(table_1)); + + ASSERT_OK_AND_ASSIGN(auto scan, ScanRelExpr::Make(table)); + EXPECT_THAT(scan, PtrEquals(scan)); + EXPECT_THAT(scan->type(), ExprType::Table(schema_1)); + 
EXPECT_THAT(scan->schema(), PtrEquals(schema_1)); + + ASSERT_OK_AND_ASSIGN(auto other, ScanRelExpr::Make(table)); + EXPECT_THAT(other, PtrEquals(scan)); +} + +TEST_F(RelExprTest, ProjectionRelExpr) { + // TODO(fsaintjacques): FILLME +} + +TEST_F(RelExprTest, FilterRelExpr) { + ASSERT_OK_AND_ASSIGN(auto empty, EmptyRelExpr::Make(schema_1)); + ASSERT_OK_AND_ASSIGN(auto pred, ScalarExpr::Make(MakeScalar(true))); + + ASSERT_RAISES(Invalid, FilterRelExpr::Make(nullptr, nullptr)); + ASSERT_RAISES(Invalid, FilterRelExpr::Make(empty, nullptr)); + ASSERT_RAISES(Invalid, FilterRelExpr::Make(nullptr, pred)); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("input must be a table"), + FilterRelExpr::Make(pred, pred)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("predicate must be a predicate"), + FilterRelExpr::Make(empty, empty)); + + ASSERT_OK_AND_ASSIGN(auto filter, FilterRelExpr::Make(empty, pred)); + EXPECT_THAT(filter, PtrEquals(filter)); + EXPECT_THAT(filter->type(), ExprType::Table(schema_1)); + EXPECT_THAT(filter->schema(), PtrEquals(schema_1)); + EXPECT_THAT(filter->operand(), PtrEquals(empty)); + EXPECT_THAT(filter->predicate(), PtrEquals(pred)); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan.cc b/cpp/src/arrow/engine/logical_plan.cc new file mode 100644 index 00000000000..bf3e745fd50 --- /dev/null +++ b/cpp/src/arrow/engine/logical_plan.cc @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/logical_plan.h" + +#include + +#include "arrow/engine/expression.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace engine { + +// +// LogicalPlan +// + +LogicalPlan::LogicalPlan(std::shared_ptr root) : root_(std::move(root)) { + DCHECK_NE(root_, nullptr); +} + +const ExprType& LogicalPlan::type() const { return root()->type(); } + +bool LogicalPlan::Equals(const LogicalPlan& other) const { + if (this == &other) { + return true; + } + + return root()->Equals(other.root()); +} + +std::string LogicalPlan::ToString() const { return root_->ToString(); } + +// +// LogicalPlanBuilder +// + +LogicalPlanBuilder::LogicalPlanBuilder(LogicalPlanBuilderOptions options) + : catalog_(options.catalog) {} + +using ResultExpr = LogicalPlanBuilder::ResultExpr; + +#define ERROR_IF_TYPE(cond, ErrorType, ...) \ + do { \ + if (ARROW_PREDICT_FALSE(cond)) { \ + return Status::ErrorType(__VA_ARGS__); \ + } \ + } while (false) + +#define ERROR_IF(cond, ...) ERROR_IF_TYPE(cond, Invalid, __VA_ARGS__) + +// +// Leaf builder. 
+// + +ResultExpr LogicalPlanBuilder::Scalar(const std::shared_ptr& scalar) { + return ScalarExpr::Make(scalar); +} + +ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, + const std::string& field_name) { + return FieldRefExpr::Make(input, field_name); +} + +ResultExpr LogicalPlanBuilder::Field(const std::shared_ptr& input, + int field_index) { + return FieldRefExpr::Make(input, field_index); +} + +ResultExpr LogicalPlanBuilder::Compare(CompareKind compare_kind, + const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + switch (compare_kind) { + case (CompareKind::EQUAL): + return EqualExpr::Make(lhs, rhs); + case (CompareKind::NOT_EQUAL): + return NotEqualExpr::Make(lhs, rhs); + case (CompareKind::GREATER_THAN): + return GreaterThanExpr::Make(lhs, rhs); + case (CompareKind::GREATER_THAN_EQUAL): + return GreaterThanEqualExpr::Make(lhs, rhs); + case (CompareKind::LESS_THAN): + return LessThanExpr::Make(lhs, rhs); + case (CompareKind::LESS_THAN_EQUAL): + return LessThanEqualExpr::Make(lhs, rhs); + } + + ARROW_UNREACHABLE; +} + +ResultExpr LogicalPlanBuilder::Equal(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::EQUAL, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::NotEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::NOT_EQUAL, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::GreaterThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::GREATER_THAN, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::GreaterThanEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::GREATER_THAN_EQUAL, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::LessThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return Compare(CompareKind::LESS_THAN, lhs, rhs); +} + +ResultExpr LogicalPlanBuilder::LessThanEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return 
Compare(CompareKind::LESS_THAN_EQUAL, lhs, rhs); +} + +// +// Aggregates +// + +ResultExpr LogicalPlanBuilder::Count(const std::shared_ptr& input) { + return CountExpr::Make(input); +} + +ResultExpr LogicalPlanBuilder::Sum(const std::shared_ptr& input) { + return SumExpr::Make(input); +} + +// +// Relational +// + +ResultExpr LogicalPlanBuilder::Scan(const std::string& table_name) { + ERROR_IF(catalog_ == nullptr, "Cannot scan from an empty catalog"); + ARROW_ASSIGN_OR_RAISE(auto table, catalog_->Get(table_name)); + return ScanRelExpr::Make(table); +} + +ResultExpr LogicalPlanBuilder::Filter(const std::shared_ptr& input, + const std::shared_ptr& predicate) { + return FilterRelExpr::Make(input, predicate); +} + +ResultExpr LogicalPlanBuilder::Project( + const std::shared_ptr& input, + const std::vector>& expressions) { + return ProjectionRelExpr::Make(input, expressions); +} + +ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, + const std::vector& column_names) { + ERROR_IF(input == nullptr, "Input expression can't be null."); + ERROR_IF(column_names.empty(), "Must have at least one column name."); + + std::vector> expressions{column_names.size()}; + for (size_t i = 0; i < column_names.size(); i++) { + ARROW_ASSIGN_OR_RAISE(expressions[i], Field(input, column_names[i])); + } + + // TODO(fsaintjacques): preserve field names. + return Project(input, expressions); +} + +ResultExpr LogicalPlanBuilder::Project(const std::shared_ptr& input, + const std::vector& column_indices) { + ERROR_IF(input == nullptr, "Input expression can't be null."); + ERROR_IF(column_indices.empty(), "Must have at least one column index."); + + std::vector> expressions{column_indices.size()}; + for (size_t i = 0; i < column_indices.size(); i++) { + ARROW_ASSIGN_OR_RAISE(expressions[i], Field(input, column_indices[i])); + } + + // TODO(fsaintjacques): preserve field names.
+ return Project(input, expressions); +} + +#undef ERROR_IF +#undef ERROR_IF_TYPE + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan.h b/cpp/src/arrow/engine/logical_plan.h new file mode 100644 index 00000000000..f5ef46809f5 --- /dev/null +++ b/cpp/src/arrow/engine/logical_plan.h @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/engine/type_fwd.h" +#include "arrow/engine/visibility.h" +#include "arrow/type_fwd.h" +#include "arrow/util/compare.h" + +namespace arrow { + +namespace dataset { +class Dataset; +} + +namespace engine { + +class ARROW_EN_EXPORT LogicalPlan : public util::EqualityComparable { + public: + explicit LogicalPlan(std::shared_ptr root); + + const std::shared_ptr& root() const { return root_; } + const ExprType& type() const; + + bool Equals(const LogicalPlan& other) const; + std::string ToString() const; + + private: + std::shared_ptr root_; +}; + +struct LogicalPlanBuilderOptions { + /// Catalog containing named tables. 
+ std::shared_ptr catalog; +}; + +class ARROW_EN_EXPORT LogicalPlanBuilder { + public: + using ResultExpr = Result>; + + explicit LogicalPlanBuilder(LogicalPlanBuilderOptions options = {}); + + /// \defgroup leaf-nodes Leaf nodes in the logical plan + /// @{ + + /// \brief Construct a Scalar literal. + ResultExpr Scalar(const std::shared_ptr& scalar); + + /// \brief References a field by index. + ResultExpr Field(const std::shared_ptr& input, int field_index); + /// \brief References a field by name. + ResultExpr Field(const std::shared_ptr& input, const std::string& field_name); + + /// \brief Scan a Table/Dataset from the Catalog. + ResultExpr Scan(const std::string& table_name); + + /// @} + + /// \defgroup comparator-nodes Comparison operators + /// @{ + + /// \brief Compare inputs. + ResultExpr Compare(CompareKind compare_kind, const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + + /// \brief Compare if inputs are equal. + ResultExpr Equal(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + + /// \brief Compare if inputs are not equal. + ResultExpr NotEqual(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + + /// \brief Compare if lhs is greater than rhs. + ResultExpr GreaterThan(const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + + /// \brief Compare if lhs is greater than or equal to rhs. + ResultExpr GreaterThanEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + + /// \brief Compare if lhs is less than rhs. + ResultExpr LessThan(const std::shared_ptr& lhs, const std::shared_ptr& rhs); + + /// \brief Compare if lhs is less than or equal to rhs. + ResultExpr LessThanEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs); + + /// @} + + /// \defgroup aggregate-nodes Aggregate function operators + /// @{ + + /// \brief Count the number of elements in the input. + ResultExpr Count(const std::shared_ptr& input); + + /// \brief Sum the elements of the input.
+ ResultExpr Sum(const std::shared_ptr& input); + + /// @} + + /// \defgroup rel-nodes Relational operator nodes in the logical plan + + /// \brief Filter rows of a relation with the given predicate. + ResultExpr Filter(const std::shared_ptr& input, + const std::shared_ptr& predicate); + + /// \brief Project (mutate) columns with given expressions. + ResultExpr Project(const std::shared_ptr& input, + const std::vector>& expressions); + + /// \brief Project (select) columns by names. + /// + /// This is a simplified version of Project where columns are selected by + /// names. Duplicates and ordering are preserved. + ResultExpr Project(const std::shared_ptr& input, + const std::vector& column_names); + + /// \brief Project (select) columns by indices. + /// + /// This is a simplified version of Project where columns are selected by + /// indices. Duplicates and ordering are preserved. + ResultExpr Project(const std::shared_ptr& input, + const std::vector& column_indices); + + /// @} + + private: + std::shared_ptr catalog_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/logical_plan_test.cc b/cpp/src/arrow/engine/logical_plan_test.cc new file mode 100644 index 00000000000..28d4607ed41 --- /dev/null +++ b/cpp/src/arrow/engine/logical_plan_test.cc @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/engine/catalog.h" +#include "arrow/engine/expression.h" +#include "arrow/engine/logical_plan.h" +#include "arrow/testing/gtest_common.h" + +using testing::HasSubstr; + +namespace arrow { +namespace engine { + +using ResultExpr = LogicalPlanBuilder::ResultExpr; + +class LogicalPlanBuilderTest : public testing::Test { + protected: + void SetUp() override { + CatalogBuilder catalog_builder; + ASSERT_OK(catalog_builder.Add(table_1, MockTable(schema_1))); + ASSERT_OK_AND_ASSIGN(options.catalog, catalog_builder.Finish()); + builder = LogicalPlanBuilder{options}; + } + + ResultExpr scalar_expr() { + auto forthy_two = MakeScalar(42); + return builder.Scalar(forthy_two); + } + + ResultExpr scan_expr() { return builder.Scan(table_1); } + + template + ResultExpr field_expr(T key, std::shared_ptr input = nullptr) { + if (input == nullptr) { + ARROW_ASSIGN_OR_RAISE(input, scan_expr()); + } + return builder.Field(input, key); + } + + ResultExpr predicate_expr() { return nullptr; } + + std::string table_1 = "table_1"; + std::shared_ptr schema_1 = schema({ + field("bool", boolean()), + field("i32", int32()), + field("u64", uint64()), + field("f32", float32()), + field("utf8", utf8()), + }); + LogicalPlanBuilderOptions options{}; + LogicalPlanBuilder builder{}; +}; + +TEST_F(LogicalPlanBuilderTest, Scalar) { + auto forthy_two = MakeScalar(42); + EXPECT_OK_AND_ASSIGN(auto scalar, builder.Scalar(forthy_two)); + ASSERT_TRUE(IsA(scalar)); +} + +TEST_F(LogicalPlanBuilderTest, FieldReferences) { + ASSERT_RAISES(Invalid, 
builder.Field(nullptr, "i32")); + ASSERT_RAISES(Invalid, builder.Field(nullptr, 0)); + + // Can't lookup a scalar + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + ASSERT_RAISES(Invalid, builder.Field(scalar, "i32")); + + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + ASSERT_RAISES(KeyError, builder.Field(table, "")); + ASSERT_RAISES(KeyError, builder.Field(table, -1)); + ASSERT_RAISES(KeyError, builder.Field(table, 9000)); + + EXPECT_OK_AND_ASSIGN(auto field_name_ref, builder.Field(table, "i32")); + ASSERT_TRUE(IsA(field_name_ref)); + + EXPECT_OK_AND_ASSIGN(auto field_idx_ref, builder.Field(table, 0)); + ASSERT_TRUE(IsA(field_idx_ref)); +} + +TEST_F(LogicalPlanBuilderTest, BasicScan) { + ASSERT_RAISES(KeyError, builder.Scan("")); + ASSERT_RAISES(KeyError, builder.Scan("not_found")); + + EXPECT_OK_AND_ASSIGN(auto scan, builder.Scan(table_1)); + ASSERT_TRUE(IsA(scan)); +} + +TEST_F(LogicalPlanBuilderTest, Comparisons) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + + EXPECT_OK_AND_ASSIGN(auto eq, builder.Equal(field, scalar)); + ASSERT_TRUE(IsA(eq)); + + EXPECT_OK_AND_ASSIGN(auto ne, builder.NotEqual(field, scalar)); + ASSERT_TRUE(IsA(ne)); + + EXPECT_OK_AND_ASSIGN(auto gt, builder.GreaterThan(field, scalar)); + ASSERT_TRUE(IsA(gt)); + + EXPECT_OK_AND_ASSIGN(auto ge, builder.GreaterThanEqual(field, scalar)); + ASSERT_TRUE(IsA(ge)); + + EXPECT_OK_AND_ASSIGN(auto lt, builder.LessThan(field, scalar)); + ASSERT_TRUE(IsA(lt)); + + EXPECT_OK_AND_ASSIGN(auto le, builder.LessThanEqual(field, scalar)); + ASSERT_TRUE(IsA(le)); +} + +TEST_F(LogicalPlanBuilderTest, Count) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + EXPECT_OK_AND_ASSIGN(auto s_count, builder.Count(scalar)); + ASSERT_TRUE(IsA(s_count)); + + EXPECT_OK_AND_ASSIGN(auto 
f_count, builder.Count(field)); + ASSERT_TRUE(IsA(f_count)); + + EXPECT_OK_AND_ASSIGN(auto t_count, builder.Count(table)); + ASSERT_TRUE(IsA(t_count)); +} + +TEST_F(LogicalPlanBuilderTest, Sum) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + EXPECT_OK_AND_ASSIGN(auto s_sum, builder.Sum(scalar)); + ASSERT_TRUE(IsA(s_sum)); + + EXPECT_OK_AND_ASSIGN(auto i32_field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto f_sum, builder.Sum(i32_field)); + ASSERT_TRUE(IsA(f_sum)); + + EXPECT_OK_AND_ASSIGN(auto str_field, field_expr("utf8", table)); + ASSERT_RAISES(Invalid, builder.Sum(str_field)); + ASSERT_RAISES(Invalid, builder.Sum(table)); +} + +TEST_F(LogicalPlanBuilderTest, Filter) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + + EXPECT_OK_AND_ASSIGN(auto field, field_expr("i32", table)); + EXPECT_OK_AND_ASSIGN(auto scalar, scalar_expr()); + EXPECT_OK_AND_ASSIGN(auto predicate, EqualExpr::Make(field, scalar)); + + EXPECT_OK_AND_ASSIGN(auto filter, builder.Filter(table, predicate)); + ASSERT_TRUE(IsA(filter)); +} + +TEST_F(LogicalPlanBuilderTest, ProjectionByNamesAndIndices) { + EXPECT_OK_AND_ASSIGN(auto table, scan_expr()); + + std::vector no_names{}; + ASSERT_RAISES(Invalid, builder.Project(table, no_names)); + std::vector invalid_names{"u64", "nope"}; + ASSERT_RAISES(KeyError, builder.Project(table, invalid_names)); + std::vector invalid_idx{42, 0}; + ASSERT_RAISES(KeyError, builder.Project(table, invalid_idx)); + + std::vector valid_names{"u64", "f32"}; + ASSERT_OK(builder.Project(table, valid_names)); + std::vector valid_idx{3, 1, 1}; + EXPECT_OK_AND_ASSIGN(auto project, builder.Project(table, valid_idx)); + ASSERT_TRUE(IsA(project)); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/pch.h b/cpp/src/arrow/engine/pch.h new file mode 100644 index 00000000000..2014fc9a1f2 --- /dev/null +++ b/cpp/src/arrow/engine/pch.h @@ -0,0 +1,24 @@ +// Licensed to the Apache
Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Often-used headers, for precompiling. +// If updating this header, please make sure you check compilation speed +// before checking in. Adding headers which are not used extremely often +// may incur a slowdown, since it makes the precompiled header heavier to load. + +#include "arrow/engine/expression.h" +#include "arrow/pch.h" diff --git a/cpp/src/arrow/engine/type_fwd.h b/cpp/src/arrow/engine/type_fwd.h new file mode 100644 index 00000000000..55d10720e68 --- /dev/null +++ b/cpp/src/arrow/engine/type_fwd.h @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/engine/visibility.h" + +namespace arrow { +namespace engine { + +class ExprType; + +/// Tag identifier for the expression type. +enum ExprKind : uint8_t { + /// A Scalar literal, i.e. a constant. + SCALAR_LITERAL, + /// A Field reference in a schema. + FIELD_REFERENCE, + + // Comparison operators, see CompareKind. + COMPARE_OP, + + // Aggregate function operators, see AggregateFnKind. + AGGREGATE_FN_OP, + + /// Empty relation with a known schema. + EMPTY_REL, + /// Scan relational operator. + SCAN_REL, + /// Projection relational operator. + PROJECTION_REL, + /// Filter relational operator. + FILTER_REL, +}; + +class Expr; +class ScalarExpr; +class FieldRefExpr; + +/// Tag identifier for comparison operators +enum CompareKind : uint8_t { + EQUAL, + NOT_EQUAL, + GREATER_THAN, + GREATER_THAN_EQUAL, + LESS_THAN, + LESS_THAN_EQUAL, +}; + +class CompareOpExpr; +class EqualExpr; +class NotEqualExpr; +class GreaterThanExpr; +class GreaterThanEqualExpr; +class LessThanExpr; +class LessThanEqualExpr; + +/// Tag identifier for aggregate function operators +enum AggregateFnKind : uint8_t { + // Count the number of elements in the input array + COUNT, + // Sum the elements of the input array. 
+ SUM, +}; + +class AggregateFnExpr; +class CountExpr; +class SumExpr; + +class RelExpr; + +class EmptyRelExpr; +class ScanRelExpr; +class ProjectionRelExpr; +class FilterRelExpr; + +class Catalog; + +class LogicalPlan; +class LogicalPlanBuilder; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/type_traits.h b/cpp/src/arrow/engine/type_traits.h new file mode 100644 index 00000000000..6e44cb6d826 --- /dev/null +++ b/cpp/src/arrow/engine/type_traits.h @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include "arrow/engine/type_fwd.h" + +namespace arrow { +namespace engine { + +template +using enable_if_t = typename std::enable_if::type; + +template +struct expr_traits; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::SCALAR_LITERAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::FIELD_REFERENCE; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::NOT_EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::GREATER_THAN; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::GREATER_THAN_EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::LESS_THAN; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::COMPARE_OP; + static constexpr auto compare_kind_id = CompareKind::LESS_THAN_EQUAL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::AGGREGATE_FN_OP; + static constexpr auto aggregate_kind_id = AggregateFnKind::COUNT; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::AGGREGATE_FN_OP; + static constexpr auto aggregate_kind_id = AggregateFnKind::SUM; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::EMPTY_REL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::SCAN_REL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = 
ExprKind::PROJECTION_REL; +}; + +template <> +struct expr_traits { + static constexpr auto kind_id = ExprKind::FILTER_REL; +}; + +template +using is_expr = std::is_base_of; + +template +using enable_if_expr = enable_if_t::value, Ret>; + +template +using is_compare_expr = std::is_base_of; + +template +using enable_if_compare_expr = enable_if_t::value, Ret>; + +template +using is_aggregate_fn_expr = std::is_base_of; + +template +using enable_if_aggregate_fn_expr = enable_if_t::value, Ret>; + +template +using is_relational_expr = std::is_base_of; + +template +using enable_if_relational_expr = enable_if_t::value, Ret>; + +// Catch-all used by `IsA` pattern matcher. +template +using enable_if_simple_expr = + enable_if_t::value && !is_aggregate_fn_expr::value, Ret>; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/visibility.h b/cpp/src/arrow/engine/visibility.h new file mode 100644 index 00000000000..0598aee3802 --- /dev/null +++ b/cpp/src/arrow/engine/visibility.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4251) +#else +#pragma GCC diagnostic ignored "-Wattributes" +#endif + +#ifdef ARROW_EN_STATIC +#define ARROW_EN_EXPORT +#elif defined(ARROW_EN_EXPORTING) +#define ARROW_EN_EXPORT __declspec(dllexport) +#else +#define ARROW_EN_EXPORT __declspec(dllimport) +#endif + +#define ARROW_EN_NO_EXPORT +#else // Not Windows +#ifndef ARROW_EN_EXPORT +#define ARROW_EN_EXPORT __attribute__((visibility("default"))) +#endif +#ifndef ARROW_EN_NO_EXPORT +#define ARROW_EN_NO_EXPORT __attribute__((visibility("hidden"))) +#endif +#endif // Non-Windows + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/cpp/src/arrow/testing/gmock.h b/cpp/src/arrow/testing/gmock.h new file mode 100644 index 00000000000..4b801679ec1 --- /dev/null +++ b/cpp/src/arrow/testing/gmock.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +namespace arrow { + +using testing::Eq; +using testing::HasSubstr; + +MATCHER_P(Equals, other, "") { return arg.Equals(other); } +MATCHER_P(PtrEquals, other, "") { return arg->Equals(*other); } +MATCHER(Ok, "") { return arg.ok(); } +MATCHER_P(OkAndEq, other, "") { return arg.ok() && arg.ValueOrDie() == other; } + +} // namespace arrow diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 8caf3f1cec9..566377a8fdf 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -321,6 +321,61 @@ void CompareBatch(const RecordBatch& left, const RecordBatch& right, } } +namespace detail { + +class MockTable : public Table { + public: + explicit MockTable(std::shared_ptr schema, int num_rows = 0) { + schema_ = std::move(schema); + num_rows_ = num_rows; + } + + static std::shared_ptr
Make(std::shared_ptr schema) { + return std::make_shared(std::move(schema)); + } + + std::shared_ptr column(int i) const override { + return std::make_shared(ArrayVector{}, schema_->field(i)->type()); + } + std::shared_ptr
Slice(int64_t offset, int64_t length) const override { + return nullptr; + } + + Status RemoveColumn(int i, std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + Status AddColumn(int i, std::shared_ptr field_arg, + std::shared_ptr column, + std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + Status SetColumn(int i, std::shared_ptr field_arg, + std::shared_ptr column, + std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + std::shared_ptr
ReplaceSchemaMetadata( + const std::shared_ptr& metadata) const override { + return nullptr; + } + + Status Flatten(MemoryPool* pool, std::shared_ptr
* out) const override { + return Status::NotImplemented("MockTable does not implement ", __FUNCTION__); + } + + Status Validate() const override { return Status::OK(); } + Status ValidateFull() const override { return Status::OK(); } +}; + +} // namespace detail + +std::shared_ptr
MockTable(std::shared_ptr schema) { + return detail::MockTable::Make(schema); +} + class LocaleGuard::Impl { public: explicit Impl(const char* new_locale) : global_locale_(std::locale()) { diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 93ea12ddcf8..e2be60c5908 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -397,6 +397,9 @@ inline void BitmapFromVector(const std::vector& is_valid, ASSERT_OK(GetBitmapFromVector(is_valid, out)); } +// Returns a table with 0 rows of a given schema. +ARROW_EXPORT std::shared_ptr
MockTable(std::shared_ptr schema); + template void AssertSortedEquals(std::vector u, std::vector v) { std::sort(u.begin(), u.end()); diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 021c9985c1e..06746fdf4cf 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -684,6 +684,10 @@ static inline bool is_floating(Type::type type_id) { return false; } +static inline bool is_numeric(Type::type type_id) { + return is_integer(type_id) || is_floating(type_id); +} + static inline bool is_primitive(Type::type type_id) { switch (type_id) { case Type::NA: diff --git a/cpp/src/arrow/util/compare.h b/cpp/src/arrow/util/compare.h index 287a30d03b2..3be798f69b8 100644 --- a/cpp/src/arrow/util/compare.h +++ b/cpp/src/arrow/util/compare.h @@ -22,6 +22,7 @@ #include #include "arrow/util/macros.h" +#include "arrow/util/visibility.h" namespace arrow { namespace util { diff --git a/cpp/src/arrow/util/macros.h b/cpp/src/arrow/util/macros.h index 7d04a80e802..90760a4c23d 100644 --- a/cpp/src/arrow/util/macros.h +++ b/cpp/src/arrow/util/macros.h @@ -61,6 +61,12 @@ #define ARROW_PREFETCH(addr) #endif +#if defined(__GNUC__) +#define ARROW_UNREACHABLE __builtin_unreachable() +#elif defined(_MSC_VER) +#define ARROW_UNREACHABLE __assume(0) +#endif + #if (defined(__GNUC__) || defined(__APPLE__)) #define ARROW_MUST_USE_RESULT __attribute__((warn_unused_result)) #elif defined(_MSC_VER) diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py index 744e92405de..9366631e547 100644 --- a/dev/archery/archery/cli.py +++ b/dev/archery/archery/cli.py @@ -155,6 +155,8 @@ def _apply_options(cmd, options): help="Build with compute kernels support.") @click.option("--with-dataset", default=False, type=BOOL, help="Build with dataset support.") +@click.option("--with-engine", default=False, type=BOOL, + help="Build with query engine support.") @click.option("--use-sanitizers", default=False, type=BOOL, help="Toggles ARROW_USE_*SAN sanitizers.") 
@click.option("--with-fuzzing", default=False, type=BOOL, diff --git a/dev/archery/archery/lang/cpp.py b/dev/archery/archery/lang/cpp.py index 607581b3c71..5eb0507bd88 100644 --- a/dev/archery/archery/lang/cpp.py +++ b/dev/archery/archery/lang/cpp.py @@ -46,7 +46,7 @@ def __init__(self, with_parquet=False, # Components with_gandiva=False, with_compute=False, with_dataset=False, - with_plasma=False, with_flight=False, + with_engine=False, with_plasma=False, with_flight=False, # extras with_lint_only=False, with_fuzzing=False, use_gold_linker=True, use_sanitizers=True, @@ -65,12 +65,13 @@ def __init__(self, self.with_benchmarks = with_benchmarks self.with_examples = with_examples self.with_python = with_python - self.with_parquet = with_parquet or with_dataset self.with_gandiva = with_gandiva self.with_plasma = with_plasma self.with_flight = with_flight self.with_compute = with_compute - self.with_dataset = with_dataset + self.with_engine = with_engine + self.with_dataset = with_dataset or self.with_engine + self.with_parquet = with_parquet or self.with_dataset self.with_lint_only = with_lint_only self.with_fuzzing = with_fuzzing @@ -147,6 +148,7 @@ def _gen_defs(self): yield ("ARROW_FLIGHT", truthifier(self.with_flight)) yield ("ARROW_COMPUTE", truthifier(self.with_compute)) yield ("ARROW_DATASET", truthifier(self.with_dataset)) + yield ("ARROW_ENGINE", truthifier(self.with_engine)) if self.use_sanitizers or self.with_fuzzing: yield ("ARROW_USE_ASAN", "ON")