From 8ccf0e9fb530b7c2f93616f7c794167e8748a4af Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Wed, 6 May 2020 18:42:47 -0500 Subject: [PATCH 01/41] New C++ compute kernels development framework project --- cpp/CMakeLists.txt | 4 + cpp/src/arrow/CMakeLists.txt | 40 +- cpp/src/arrow/array/diff_test.cc | 13 +- cpp/src/arrow/compute/CMakeLists.txt | 9 +- cpp/src/arrow/compute/README.md | 58 + cpp/src/arrow/compute/api.h | 23 +- cpp/src/arrow/compute/api_eager.cc | 209 +++ cpp/src/arrow/compute/api_eager.h | 341 ++++ cpp/src/arrow/compute/cast.h | 68 + cpp/src/arrow/compute/compute_test.cc | 95 - cpp/src/arrow/compute/context.h | 79 - cpp/src/arrow/compute/exec.cc | 859 +++++++++ cpp/src/arrow/compute/exec.h | 175 ++ cpp/src/arrow/compute/exec_internal.h | 128 ++ cpp/src/arrow/compute/exec_test.cc | 840 +++++++++ cpp/src/arrow/compute/filter.h | 54 + cpp/src/arrow/compute/function.cc | 150 ++ cpp/src/arrow/compute/function.h | 197 +++ cpp/src/arrow/compute/function_test.cc | 239 +++ cpp/src/arrow/compute/kernel.cc | 296 ++++ cpp/src/arrow/compute/kernel.h | 655 ++++--- cpp/src/arrow/compute/kernel_test.cc | 430 +++++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 66 +- cpp/src/arrow/compute/kernels/add.cc | 131 -- cpp/src/arrow/compute/kernels/add.h | 77 - cpp/src/arrow/compute/kernels/aggregate.cc | 88 - cpp/src/arrow/compute/kernels/aggregate.h | 115 -- .../arrow/compute/kernels/aggregate_basic.cc | 366 ++++ .../compute/kernels/aggregate_benchmark.cc | 9 +- .../kernels/{isin.h => aggregate_internal.h} | 43 +- .../arrow/compute/kernels/aggregate_test.cc | 292 +--- cpp/src/arrow/compute/kernels/boolean.cc | 269 --- cpp/src/arrow/compute/kernels/boolean.h | 105 -- cpp/src/arrow/compute/kernels/cast.cc | 1549 ----------------- cpp/src/arrow/compute/kernels/cast.h | 101 -- .../arrow/compute/kernels/codegen_internal.cc | 145 ++ .../arrow/compute/kernels/codegen_internal.h | 429 +++++ cpp/src/arrow/compute/kernels/common.h | 50 + cpp/src/arrow/compute/kernels/compare.cc | 
332 ---- cpp/src/arrow/compute/kernels/compare.h | 72 - cpp/src/arrow/compute/kernels/count.cc | 115 -- cpp/src/arrow/compute/kernels/count.h | 88 - cpp/src/arrow/compute/kernels/filter.h | 105 -- cpp/src/arrow/compute/kernels/hash.h | 102 -- cpp/src/arrow/compute/kernels/isin.cc | 306 ---- cpp/src/arrow/compute/kernels/isin_test.cc | 415 ----- cpp/src/arrow/compute/kernels/match.cc | 281 --- cpp/src/arrow/compute/kernels/match.h | 57 - cpp/src/arrow/compute/kernels/match_test.cc | 389 ----- cpp/src/arrow/compute/kernels/mean.cc | 116 -- cpp/src/arrow/compute/kernels/mean.h | 66 - cpp/src/arrow/compute/kernels/minmax.cc | 142 +- cpp/src/arrow/compute/kernels/minmax.h | 98 -- .../arrow/compute/kernels/nth_to_indices.cc | 140 -- .../arrow/compute/kernels/nth_to_indices.h | 53 - .../kernels/{sort_to_indices.h => registry.h} | 41 +- .../scalar_arithmetic.cc} | 43 +- ...{add_test.cc => scalar_arithmetic_test.cc} | 21 +- .../arrow/compute/kernels/scalar_boolean.cc | 183 ++ ...boolean_test.cc => scalar_boolean_test.cc} | 46 +- cpp/src/arrow/compute/kernels/scalar_cast.cc | 449 +++++ .../compute/kernels/scalar_cast_boolean.cc | 82 + .../compute/kernels/scalar_cast_decimal.cc | 88 + .../compute/kernels/scalar_cast_internal.h | 222 +++ .../compute/kernels/scalar_cast_nested.cc | 64 + .../compute/kernels/scalar_cast_numeric.cc | 425 +++++ .../compute/kernels/scalar_cast_string.cc | 110 ++ .../compute/kernels/scalar_cast_temporal.cc | 276 +++ .../{cast_test.cc => scalar_cast_test.cc} | 88 +- .../arrow/compute/kernels/scalar_compare.cc | 117 ++ ...nchmark.cc => scalar_compare_benchmark.cc} | 12 +- ...compare_test.cc => scalar_compare_test.cc} | 330 ++-- .../compute/kernels/scalar_set_lookup.cc | 317 ++++ .../compute/kernels/scalar_set_lookup_test.cc | 677 +++++++ cpp/src/arrow/compute/kernels/sum.cc | 114 -- cpp/src/arrow/compute/kernels/sum.h | 71 - cpp/src/arrow/compute/kernels/sum_internal.h | 207 --- .../arrow/compute/kernels/util_internal.cc | 337 ---- 
cpp/src/arrow/compute/kernels/util_internal.h | 154 -- .../compute/kernels/util_internal_test.cc | 247 --- .../kernels/{filter.cc => vector_filter.cc} | 154 +- ...enchmark.cc => vector_filter_benchmark.cc} | 20 +- .../{filter_test.cc => vector_filter_test.cc} | 116 +- .../kernels/{hash.cc => vector_hash.cc} | 76 +- .../vector_hash_benchmark.cc} | 21 +- .../{hash_test.cc => vector_hash_test.cc} | 241 ++- .../arrow/compute/kernels/vector_partition.cc | 107 ++ ...hmark.cc => vector_partition_benchmark.cc} | 7 +- ...dices_test.cc => vector_partition_test.cc} | 78 +- .../{sort_to_indices.cc => vector_sort.cc} | 4 +- ..._benchmark.cc => vector_sort_benchmark.cc} | 8 +- ...to_indices_test.cc => vector_sort_test.cc} | 23 +- .../kernels/{take.cc => vector_take.cc} | 30 +- ..._benchmark.cc => vector_take_benchmark.cc} | 8 +- ...take_internal.h => vector_take_internal.h} | 2 +- .../{take_test.cc => vector_take_test.cc} | 49 +- cpp/src/arrow/compute/options.h | 155 ++ cpp/src/arrow/compute/registry.cc | 124 ++ cpp/src/arrow/compute/registry.h | 74 + cpp/src/arrow/compute/registry_test.cc | 89 + cpp/src/arrow/compute/{kernels => }/take.h | 118 +- cpp/src/arrow/compute/test_util.h | 24 +- cpp/src/arrow/dataset/filter.cc | 86 +- cpp/src/arrow/dataset/filter.h | 81 +- cpp/src/arrow/dataset/filter_test.cc | 21 +- cpp/src/arrow/dataset/scanner.h | 1 - cpp/src/arrow/dataset/scanner_internal.h | 2 +- cpp/src/arrow/dataset/scanner_test.cc | 1 - cpp/src/arrow/dataset/type_fwd.h | 2 +- cpp/src/arrow/datum.cc | 188 ++ cpp/src/arrow/datum.h | 270 +++ cpp/src/arrow/datum_test.cc | 161 ++ cpp/src/arrow/python/arrow_to_pandas.cc | 26 +- cpp/src/arrow/python/numpy_to_arrow.cc | 22 +- cpp/src/arrow/python/numpy_to_arrow.h | 2 +- cpp/src/arrow/stl.h | 14 +- cpp/src/arrow/stl_test.cc | 6 +- cpp/src/arrow/testing/gtest_util.cc | 16 +- cpp/src/arrow/testing/gtest_util.h | 8 +- cpp/src/arrow/type.cc | 86 + cpp/src/arrow/type.h | 10 + cpp/src/arrow/type_fwd.h | 1 + cpp/src/gandiva/arrow.h | 27 +- 
.../parquet/arrow/arrow_reader_writer_test.cc | 16 +- cpp/src/parquet/arrow/reader_internal.cc | 4 +- cpp/src/parquet/column_writer.cc | 12 +- cpp/src/parquet/encoding_test.cc | 1 - testing | 2 +- 128 files changed, 10719 insertions(+), 8270 deletions(-) create mode 100644 cpp/src/arrow/compute/README.md create mode 100644 cpp/src/arrow/compute/api_eager.cc create mode 100644 cpp/src/arrow/compute/api_eager.h create mode 100644 cpp/src/arrow/compute/cast.h delete mode 100644 cpp/src/arrow/compute/compute_test.cc delete mode 100644 cpp/src/arrow/compute/context.h create mode 100644 cpp/src/arrow/compute/exec.cc create mode 100644 cpp/src/arrow/compute/exec.h create mode 100644 cpp/src/arrow/compute/exec_internal.h create mode 100644 cpp/src/arrow/compute/exec_test.cc create mode 100644 cpp/src/arrow/compute/filter.h create mode 100644 cpp/src/arrow/compute/function.cc create mode 100644 cpp/src/arrow/compute/function.h create mode 100644 cpp/src/arrow/compute/function_test.cc create mode 100644 cpp/src/arrow/compute/kernel.cc create mode 100644 cpp/src/arrow/compute/kernel_test.cc delete mode 100644 cpp/src/arrow/compute/kernels/add.cc delete mode 100644 cpp/src/arrow/compute/kernels/add.h delete mode 100644 cpp/src/arrow/compute/kernels/aggregate.cc delete mode 100644 cpp/src/arrow/compute/kernels/aggregate.h create mode 100644 cpp/src/arrow/compute/kernels/aggregate_basic.cc rename cpp/src/arrow/compute/kernels/{isin.h => aggregate_internal.h} (53%) delete mode 100644 cpp/src/arrow/compute/kernels/boolean.cc delete mode 100644 cpp/src/arrow/compute/kernels/boolean.h delete mode 100644 cpp/src/arrow/compute/kernels/cast.cc delete mode 100644 cpp/src/arrow/compute/kernels/cast.h create mode 100644 cpp/src/arrow/compute/kernels/codegen_internal.cc create mode 100644 cpp/src/arrow/compute/kernels/codegen_internal.h create mode 100644 cpp/src/arrow/compute/kernels/common.h delete mode 100644 cpp/src/arrow/compute/kernels/compare.cc delete mode 100644 
cpp/src/arrow/compute/kernels/compare.h delete mode 100644 cpp/src/arrow/compute/kernels/count.cc delete mode 100644 cpp/src/arrow/compute/kernels/count.h delete mode 100644 cpp/src/arrow/compute/kernels/filter.h delete mode 100644 cpp/src/arrow/compute/kernels/hash.h delete mode 100644 cpp/src/arrow/compute/kernels/isin.cc delete mode 100644 cpp/src/arrow/compute/kernels/isin_test.cc delete mode 100644 cpp/src/arrow/compute/kernels/match.cc delete mode 100644 cpp/src/arrow/compute/kernels/match.h delete mode 100644 cpp/src/arrow/compute/kernels/match_test.cc delete mode 100644 cpp/src/arrow/compute/kernels/mean.cc delete mode 100644 cpp/src/arrow/compute/kernels/mean.h delete mode 100644 cpp/src/arrow/compute/kernels/minmax.h delete mode 100644 cpp/src/arrow/compute/kernels/nth_to_indices.cc delete mode 100644 cpp/src/arrow/compute/kernels/nth_to_indices.h rename cpp/src/arrow/compute/kernels/{sort_to_indices.h => registry.h} (53%) rename cpp/src/arrow/compute/{context.cc => kernels/scalar_arithmetic.cc} (51%) rename cpp/src/arrow/compute/kernels/{add_test.cc => scalar_arithmetic_test.cc} (82%) create mode 100644 cpp/src/arrow/compute/kernels/scalar_boolean.cc rename cpp/src/arrow/compute/kernels/{boolean_test.cc => scalar_boolean_test.cc} (89%) create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_decimal.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_internal.h create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_nested.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_string.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc rename cpp/src/arrow/compute/kernels/{cast_test.cc => scalar_cast_test.cc} (96%) create mode 100644 cpp/src/arrow/compute/kernels/scalar_compare.cc rename 
cpp/src/arrow/compute/kernels/{compare_benchmark.cc => scalar_compare_benchmark.cc} (89%) rename cpp/src/arrow/compute/kernels/{compare_test.cc => scalar_compare_test.cc} (52%) create mode 100644 cpp/src/arrow/compute/kernels/scalar_set_lookup.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc delete mode 100644 cpp/src/arrow/compute/kernels/sum.cc delete mode 100644 cpp/src/arrow/compute/kernels/sum.h delete mode 100644 cpp/src/arrow/compute/kernels/sum_internal.h delete mode 100644 cpp/src/arrow/compute/kernels/util_internal.cc delete mode 100644 cpp/src/arrow/compute/kernels/util_internal.h delete mode 100644 cpp/src/arrow/compute/kernels/util_internal_test.cc rename cpp/src/arrow/compute/kernels/{filter.cc => vector_filter.cc} (64%) rename cpp/src/arrow/compute/kernels/{filter_benchmark.cc => vector_filter_benchmark.cc} (86%) rename cpp/src/arrow/compute/kernels/{filter_test.cc => vector_filter_test.cc} (87%) rename cpp/src/arrow/compute/kernels/{hash.cc => vector_hash.cc} (87%) rename cpp/src/arrow/compute/{compute_benchmark.cc => kernels/vector_hash_benchmark.cc} (93%) rename cpp/src/arrow/compute/kernels/{hash_test.cc => vector_hash_test.cc} (68%) create mode 100644 cpp/src/arrow/compute/kernels/vector_partition.cc rename cpp/src/arrow/compute/kernels/{nth_to_indices_benchmark.cc => vector_partition_benchmark.cc} (90%) rename cpp/src/arrow/compute/kernels/{nth_to_indices_test.cc => vector_partition_test.cc} (59%) rename cpp/src/arrow/compute/kernels/{sort_to_indices.cc => vector_sort.cc} (99%) rename cpp/src/arrow/compute/kernels/{sort_to_indices_benchmark.cc => vector_sort_benchmark.cc} (92%) rename cpp/src/arrow/compute/kernels/{sort_to_indices_test.cc => vector_sort_test.cc} (90%) rename cpp/src/arrow/compute/kernels/{take.cc => vector_take.cc} (87%) rename cpp/src/arrow/compute/kernels/{take_benchmark.cc => vector_take_benchmark.cc} (95%) rename cpp/src/arrow/compute/kernels/{take_internal.h => vector_take_internal.h} (99%) 
rename cpp/src/arrow/compute/kernels/{take_test.cc => vector_take_test.cc} (93%) create mode 100644 cpp/src/arrow/compute/options.h create mode 100644 cpp/src/arrow/compute/registry.cc create mode 100644 cpp/src/arrow/compute/registry.h create mode 100644 cpp/src/arrow/compute/registry_test.cc rename cpp/src/arrow/compute/{kernels => }/take.h (60%) create mode 100644 cpp/src/arrow/datum.cc create mode 100644 cpp/src/arrow/datum.h create mode 100644 cpp/src/arrow/datum_test.cc diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c5f65835499..662461ec89f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -309,6 +309,10 @@ if(ARROW_DATASET) set(ARROW_FILESYSTEM ON) endif() +if(ARROW_GANDIVA) + set(ARROW_COMPUTE ON) +endif() + if(ARROW_PARQUET) set(ARROW_COMPUTE ON) endif() diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index b06147f2247..2e62391c442 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -121,7 +121,6 @@ endfunction() set(ARROW_SRCS array.cc - builder.cc array/builder_adaptive.cc array/builder_base.cc array/builder_binary.cc @@ -134,8 +133,10 @@ set(ARROW_SRCS array/dict_internal.cc array/diff.cc array/validate.cc + builder.cc buffer.cc compare.cc + datum.cc device.cc extension_type.cc memory_pool.cc @@ -319,24 +320,24 @@ endif() if(ARROW_COMPUTE) list(APPEND ARROW_SRCS - compute/context.cc - compute/kernels/aggregate.cc - compute/kernels/boolean.cc - compute/kernels/cast.cc - compute/kernels/compare.cc - compute/kernels/count.cc - compute/kernels/hash.cc - compute/kernels/filter.cc - compute/kernels/mean.cc - compute/kernels/minmax.cc - compute/kernels/sort_to_indices.cc - compute/kernels/nth_to_indices.cc - compute/kernels/sum.cc - compute/kernels/add.cc - compute/kernels/take.cc - compute/kernels/isin.cc - compute/kernels/match.cc - compute/kernels/util_internal.cc) + compute/api_eager.cc + compute/exec.cc + compute/function.cc + compute/kernel.cc + compute/registry.cc + 
compute/kernels/codegen_internal.cc + compute/kernels/aggregate_basic.cc + compute/kernels/scalar_arithmetic.cc + compute/kernels/scalar_boolean.cc + compute/kernels/scalar_compare.cc + compute/kernels/scalar_set_lookup.cc + compute/kernels/vector_partition.cc + # compute/kernels/scalar_cast.cc + # compute/kernels/filter.cc + # compute/kernels/take.cc + # compute/kernels/hash.cc + # compute/kernels/sort_to_indices.cc + ) endif() if(ARROW_FILESYSTEM) @@ -524,6 +525,7 @@ endif() add_arrow_test(misc_test SOURCES + datum_test.cc memory_pool_test.cc result_test.cc pretty_print_test.cc diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc index 0e9ccc40504..4917d4524d1 100644 --- a/cpp/src/arrow/array/diff_test.cc +++ b/cpp/src/arrow/array/diff_test.cc @@ -33,8 +33,7 @@ #include "arrow/array/diff.h" #include "arrow/buffer.h" #include "arrow/builder.h" -#include "arrow/compute/context.h" -#include "arrow/compute/kernels/filter.h" +#include "arrow/compute/api.h" #include "arrow/status.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/random.h" @@ -119,20 +118,19 @@ class DiffTest : public ::testing::Test { void BaseAndTargetFromRandomFilter(std::shared_ptr values, double filter_probability) { - compute::Datum out_datum, base_filter, target_filter; + std::shared_ptr base_filter, target_filter; do { base_filter = this->rng_.Boolean(values->length(), filter_probability, 0.0); target_filter = this->rng_.Boolean(values->length(), filter_probability, 0.0); - } while (base_filter.Equals(target_filter)); + } while (base_filter->Equals(target_filter)); - ASSERT_OK(compute::Filter(&ctx_, values, base_filter, {}, &out_datum)); + ASSERT_OK_AND_ASSIGN(Datum out_datum, compute::Filter(values, base_filter)); base_ = out_datum.make_array(); - ASSERT_OK(compute::Filter(&ctx_, values, target_filter, {}, &out_datum)); + ASSERT_OK_AND_ASSIGN(out_datum, compute::Filter(values, target_filter)); target_ = out_datum.make_array(); } - 
compute::FunctionContext ctx_; random::RandomArrayGenerator rng_; std::shared_ptr edits_; std::shared_ptr base_, target_; @@ -616,7 +614,6 @@ void MakeSameLength(std::shared_ptr* a, std::shared_ptr* b) { } TEST_F(DiffTest, CompareRandomStruct) { - compute::FunctionContext ctx; for (auto null_probability : {0.0, 0.25}) { constexpr auto length = 1 << 10; auto int32_values = this->rng_.Int32(length, 0, 127, null_probability); diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt index 495a4a3f944..8ee87047a3d 100644 --- a/cpp/src/arrow/compute/CMakeLists.txt +++ b/cpp/src/arrow/compute/CMakeLists.txt @@ -58,7 +58,12 @@ function(ADD_ARROW_COMPUTE_TEST REL_TEST_NAME) ${ARG_UNPARSED_ARGUMENTS}) endfunction() -add_arrow_compute_test(compute_test) -add_arrow_benchmark(compute_benchmark) +add_arrow_compute_test(internals_test + SOURCES + function_test.cc + exec_test.cc + kernel_test.cc + registry_test.cc) +add_arrow_compute_test(exec_test) add_subdirectory(kernels) diff --git a/cpp/src/arrow/compute/README.md b/cpp/src/arrow/compute/README.md new file mode 100644 index 00000000000..80d8918e3d9 --- /dev/null +++ b/cpp/src/arrow/compute/README.md @@ -0,0 +1,58 @@ + + +## Apache Arrow C++ Compute Functions + +This submodule contains analytical functions that process primarily Arrow +columnar data; some functions can process scalar or Arrow-based array +inputs. These are intended for use inside query engines, data frame libraries, +etc. + +Many functions have SQL-like semantics in that they perform elementwise or +scalar operations on whole arrays at a time. Other functions are not SQL-like +and compute results that may be a different length or whose results depend on +the order of the values. + +Some basic terminology: + +* We use the term "function" to refer to particular general operation that may + have many different implementations corresponding to different combinations + of types or function behavior options. 
+* We call a specific implementation of a function a "kernel". When executing a + function on inputs, we must first select a suitable kernel (kernel selection + is called "dispatching") corresponding to the value types of the inputs. +* Functions along with their kernel implementations are collected in a + "function registry". Given a function name and argument types, we can look up + that function and dispatch to a compatible kernel. + +Types of functions + +* Scalar functions: elementwise functions that perform scalar operations in a + vectorized manner. These functions are generally valid in a SQL-like + context. These are called "scalar" in that the functions executed consider + each value in an array independently, and the output array or arrays have the + same length as the input arrays. The result for each array cell is generally + independent of its position in the array. +* Vector functions, which produce a result whose output is generally dependent + on the entire contents of the input arrays. These functions **are generally + not valid** for SQL-like processing because the output size may be different + than the input size, and the result may change based on the order of the + values in the array. This includes things like array subselection, sorting, + hashing, and more. +* Scalar aggregate functions, which can be used in a SQL-like context \ No newline at end of file diff --git a/cpp/src/arrow/compute/api.h b/cpp/src/arrow/compute/api.h index 8e60a39a0fd..8c3a2ac08ba 100644 --- a/cpp/src/arrow/compute/api.h +++ b/cpp/src/arrow/compute/api.h @@ -15,20 +15,13 @@ // specific language governing permissions and limitations // under the License.
-#pragma once +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle -#include "arrow/compute/context.h" // IWYU pragma: export -#include "arrow/compute/kernel.h" // IWYU pragma: export +#pragma once -#include "arrow/compute/kernels/boolean.h" // IWYU pragma: export -#include "arrow/compute/kernels/cast.h" // IWYU pragma: export -#include "arrow/compute/kernels/compare.h" // IWYU pragma: export -#include "arrow/compute/kernels/count.h" // IWYU pragma: export -#include "arrow/compute/kernels/filter.h" // IWYU pragma: export -#include "arrow/compute/kernels/hash.h" // IWYU pragma: export -#include "arrow/compute/kernels/isin.h" // IWYU pragma: export -#include "arrow/compute/kernels/mean.h" // IWYU pragma: export -#include "arrow/compute/kernels/nth_to_indices.h" // IWYU pragma: export -#include "arrow/compute/kernels/sort_to_indices.h" // IWYU pragma: export -#include "arrow/compute/kernels/sum.h" // IWYU pragma: export -#include "arrow/compute/kernels/take.h" // IWYU pragma: export +#include "arrow/compute/api_eager.h" // IWYU pragma: export +#include "arrow/compute/exec.h" // IWYU pragma: export +#include "arrow/compute/function.h" // IWYU pragma: export +#include "arrow/compute/kernel.h" // IWYU pragma: export +#include "arrow/compute/registry.h" // IWYU pragma: export diff --git a/cpp/src/arrow/compute/api_eager.cc b/cpp/src/arrow/compute/api_eager.cc new file mode 100644 index 00000000000..129a40f69f8 --- /dev/null +++ b/cpp/src/arrow/compute/api_eager.cc @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_eager.h" + +#include +#include +#include + +#include "arrow/compute/exec.h" + +namespace arrow { +namespace compute { + +#define SCALAR_EAGER_UNARY(NAME, REGISTRY_NAME) \ + Result NAME(const Datum& value, ExecContext* ctx) { \ + return ExecScalarFunction(ctx, REGISTRY_NAME, {value}); \ + } + +#define SCALAR_EAGER_BINARY(NAME, REGISTRY_NAME) \ + Result NAME(const Datum& left, const Datum& right, ExecContext* ctx) { \ + return ExecScalarFunction(ctx, REGISTRY_NAME, {left, right}); \ + } + +// ---------------------------------------------------------------------- +// Arithmetic + +SCALAR_EAGER_BINARY(Add, "add") + +// ---------------------------------------------------------------------- +// Set-related operations + +static Result ExecSetLookup(const std::string& func_name, const Datum& data, + std::shared_ptr value_set, + bool add_nulls_to_hash_table, ExecContext* ctx) { + if (value_set->length() > 0 && !data.type()->Equals(value_set->type())) { + std::stringstream ss; + ss << "Array type didn't match type of values set: " << data.type()->ToString() + << " vs " << value_set->type()->ToString(); + return Status::Invalid(ss.str()); + } + SetLookupOptions options(std::move(value_set), !add_nulls_to_hash_table); + return ExecScalarFunction(ctx, func_name, {data}, &options); +} + +Result IsIn(const Datum& values, std::shared_ptr value_set, + ExecContext* ctx) { + return ExecSetLookup("isin", values, std::move(value_set), + /*add_nulls_to_hash_table=*/false, ctx); +} + +Result Match(const Datum& values, std::shared_ptr 
value_set, + ExecContext* ctx) { + return ExecSetLookup("match", values, std::move(value_set), + /*add_nulls_to_hash_table=*/true, ctx); +} + +// ---------------------------------------------------------------------- +// Boolean functions + +SCALAR_EAGER_UNARY(Invert, "invert") +SCALAR_EAGER_BINARY(And, "and") +SCALAR_EAGER_BINARY(KleeneAnd, "and_kleene") +SCALAR_EAGER_BINARY(Or, "or") +SCALAR_EAGER_BINARY(KleeneOr, "or_kleene") +SCALAR_EAGER_BINARY(Xor, "xor") + +// ---------------------------------------------------------------------- + +Result Compare(const Datum& left, const Datum& right, CompareOptions options, + ExecContext* ctx) { + std::string func_name; + switch (options.op) { + case CompareOperator::EQUAL: + func_name = "=="; + break; + case CompareOperator::NOT_EQUAL: + func_name = "!="; + break; + case CompareOperator::GREATER: + func_name = ">"; + break; + case CompareOperator::GREATER_EQUAL: + func_name = ">="; + break; + case CompareOperator::LESS: + func_name = "<"; + break; + case CompareOperator::LESS_EQUAL: + func_name = "<="; + break; + default: + DCHECK(false); + break; + } + return ExecScalarFunction(ctx, func_name, {left, right}, &options); +} + +// ---------------------------------------------------------------------- +// Scalar aggregates + +Result Count(const Datum& value, CountOptions options, ExecContext* ctx) { + return ExecScalarAggregateFunction(ctx, "count", {value}, &options); +} + +Result Mean(const Datum& value, ExecContext* ctx) { + return ExecScalarAggregateFunction(ctx, "mean", {value}); +} + +Result Sum(const Datum& value, ExecContext* ctx) { + return ExecScalarAggregateFunction(ctx, "sum", {value}); +} + +// Result MinMax(const Datum& value, const MinMaxOptions& options, +// ExecContext* ctx) { +// return ExecScalarAggregateFunction(ctx, "minmax", {value}); +// } + +// ---------------------------------------------------------------------- +// Vector functions + +namespace { + +// Status InvokeHash(FunctionContext* ctx, 
HashKernel* func, const Datum& value, +// std::vector* kernel_outputs, +// std::shared_ptr* dictionary) { +// RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, func, value, kernel_outputs)); +// std::shared_ptr dict_data; +// RETURN_NOT_OK(func->GetDictionary(&dict_data)); +// *dictionary = MakeArray(dict_data); +// return Status::OK(); +// } + +} // namespace + +Result> Unique(const Datum& value, ExecContext* ctx) { + // std::unique_ptr func; + // RETURN_NOT_OK(GetUniqueKernel(ctx, value.type(), &func)); + // std::vector dummy_outputs; + // return InvokeHash(ctx, func.get(), value, &dummy_outputs, out); + return Status::NotImplemented("NYI"); +} + +Result DictionaryEncode(const Datum& value, ExecContext* ctx) { + // std::unique_ptr func; + // RETURN_NOT_OK(GetDictionaryEncodeKernel(ctx, value.type(), &func)); + // std::shared_ptr dict; + // std::vector indices_outputs; + // RETURN_NOT_OK(InvokeHash(ctx, func.get(), value, &indices_outputs, &dict)); + // auto dict_type = dictionary(func->out_type(), dict->type()); + // // Wrap indices in dictionary arrays for result + // std::vector> dict_chunks; + // for (const Datum& datum : indices_outputs) { + // dict_chunks.emplace_back( + // std::make_shared(dict_type, datum.make_array(), dict)); + // } + // *out = detail::WrapArraysLike(value, dict_type, dict_chunks); + // return Status::OK(); + return Status::NotImplemented("NYI"); +} + +const char kValuesFieldName[] = "values"; +const char kCountsFieldName[] = "counts"; +const int32_t kValuesFieldIndex = 0; +const int32_t kCountsFieldIndex = 1; + +Result> ValueCounts(const Datum& value, ExecContext* ctx) { + // std::unique_ptr func; + // RETURN_NOT_OK(GetValueCountsKernel(ctx, value.type(), &func)); + // // Calls return nothing for counts. 
+ // std::vector unused_output; + // std::shared_ptr uniques; + // RETURN_NOT_OK(InvokeHash(ctx, func.get(), value, &unused_output, &uniques)); + // Datum value_counts; + // RETURN_NOT_OK(func->FlushFinal(&value_counts)); + // auto data_type = std::make_shared(std::vector>{ + // std::make_shared(kValuesFieldName, uniques->type()), + // std::make_shared(kCountsFieldName, int64())}); + // *counts = std::make_shared( + // data_type, uniques->length(), + // std::vector>{uniques, MakeArray(value_counts.array())}); + // return Status::OK(); + return Status::NotImplemented("NYI"); +} + +Result> PartitionIndices(const Array& values, int64_t n, + ExecContext* ctx) { + PartitionOptions options(/*pivot=*/n); + ARROW_ASSIGN_OR_RAISE(Datum result, ExecVectorFunction(ctx, "partition_indices", + {Datum(values)}, &options)); + return result.make_array(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/api_eager.h b/cpp/src/arrow/compute/api_eager.h new file mode 100644 index 00000000000..d41210c9594 --- /dev/null +++ b/cpp/src/arrow/compute/api_eager.h @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include + +#include "arrow/compute/cast.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/filter.h" +#include "arrow/compute/options.h" +#include "arrow/compute/take.h" +#include "arrow/datum.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +// ---------------------------------------------------------------------- + +/// \brief Add two values together. Array values must be the same length. If a +/// value is null in either addend, the result is null +/// +/// \param[in] left the first value +/// \param[in] right the second value +/// \param[in] ctx the function execution context, optional +/// \return the elementwise addition of the values +ARROW_EXPORT +Result Add(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Compare a numeric array with a scalar. +/// +/// \param[in] left datum to compare, must be an Array +/// \param[in] right datum to compare, must be a Scalar of the same type as the +/// left Datum. +/// \param[in] options compare options +/// \param[in] ctx the function execution context, optional +/// \return resulting datum +/// +/// Note: on floating point arrays, this uses IEEE 754 compare semantics. +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Compare(const Datum& left, const Datum& right, + struct CompareOptions options, ExecContext* ctx = NULLPTR); + +/// \brief Invert the values of a boolean datum +/// \param[in] value datum to invert +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Invert(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums which always propagates nulls +/// (null and false is null).
+/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result And(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums with a Kleene truth table +/// (null and false is false). +/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result KleeneAnd(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums which always propagates nulls +/// (null and true is null). +/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Or(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums with a Kleene truth table +/// (null or true is true). 
+/// +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result KleeneOr(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise XOR of two boolean datums +/// \param[in] left left operand (array) +/// \param[in] right right operand (array) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Xor(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief IsIn returns true for each element of `values` that is contained in +/// `value_set` +/// +/// If null occurs in left, if null count in right is not 0, +/// it returns true, else returns null. +/// +/// \param[in] values array-like input to look up in value_set +/// \param[in] value_set Array input +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IsIn(const Datum& values, std::shared_ptr value_set, + ExecContext* ctx = NULLPTR); + +/// \brief Match examines each slot in the haystack against a needles array. +/// If the value is not found in needles, null will be output. +/// If found, the index of occurrence within needles (ignoring duplicates) +/// will be output. +/// +/// For example given haystack = [99, 42, 3, null] and +/// needles = [3, 3, 99], the output will be = [1, null, 0, null] +/// +/// Note: Null in the haystack is considered to match +/// a null in the needles array. 
For example given +/// haystack = [99, 42, 3, null] and needles = [3, 99, null], +/// the output will be = [1, null, 0, 2] +/// +/// \param[in] haystack array-like input +/// \param[in] needles Array input +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Match(const Datum& haystack, std::shared_ptr needles, + ExecContext* ctx = NULLPTR); + +/// \brief Returns indices that partition an array around n-th +/// sorted element. +/// +/// Find index of n-th(0 based) smallest value and perform indirect +/// partition of an array around that element. Output indices[0 ~ n-1] +/// holds values no greater than n-th element, and indices[n+1 ~ end] +/// holds values no less than n-th element. Elements in each partition +/// is not sorted. Nulls will be partitioned to the end of the output. +/// Output is not guaranteed to be stable. +/// +/// \param[in] values array to be partitioned +/// \param[in] n pivot array around sorted n-th element +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would partition an array +ARROW_EXPORT +Result> PartitionIndices(const Array& values, int64_t n, + ExecContext* ctx = NULLPTR); + +ARROW_DEPRECATED("Deprecated in 1.0.0. Use PartitionIndices") +Result> NthToIndices(const Array& values, int64_t n, + ExecContext* ctx = NULLPTR) { + return PartitionIndices(values, n, ctx); +} + +/// \brief Returns the indices that would sort an array. +/// +/// Perform an indirect sort of array. The output array will contain +/// indices that would sort an array, which would be the same length +/// as input. Nulls will be stably partitioned to the end of the output. 
+/// +/// For example given values = [null, 1, 3.3, null, 2, 5.3], the output +/// will be [1, 4, 2, 5, 0, 3] +/// +/// \param[in] values array to sort +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortToIndices(const Array& values, + ExecContext* ctx = NULLPTR); + +/// \brief Compute unique elements from an array-like object +/// +/// Note if a null occurs in the input it will NOT be included in the output. +/// +/// \param[in] datum array-like input +/// \param[in] ctx the function execution context, optional +/// \return result as Array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> Unique(const Datum& datum, ExecContext* ctx = NULLPTR); + +// Constants for accessing the output of ValueCounts +ARROW_EXPORT extern const char kValuesFieldName[]; +ARROW_EXPORT extern const char kCountsFieldName[]; +ARROW_EXPORT extern const int32_t kValuesFieldIndex; +ARROW_EXPORT extern const int32_t kCountsFieldIndex; +/// \brief Return counts of unique elements from an array-like object. +/// +/// Note that the counts do not include counts for nulls in the array. These can be +/// obtained separately from metadata. +/// +/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values +/// which can lead to unexpected results if the input Array has these values. +/// +/// \param[in] value array-like input +/// \param[in] ctx the function execution context, optional +/// \return counts An array of structs. 
+/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> ValueCounts(const Datum& value, + ExecContext* ctx = NULLPTR); + +/// \brief Dictionary-encode values in an array-like object +/// \param[in] data array-like input +/// \param[in] ctx the function execution context, optional +/// \return result with same shape and type as input +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result DictionaryEncode(const Datum& data, ExecContext* ctx = NULLPTR); + +// ---------------------------------------------------------------------- +// Aggregate functions + +/// \brief Count non-null (or null) values in an array. +/// +/// \param[in] options counting options, see CountOptions for more information +/// \param[in] datum to count +/// \param[in] ctx the function execution context, optional +/// \return out resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Count(const Datum& datum, CountOptions options = CountOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the mean of a numeric array. +/// +/// \param[in] value datum to compute the mean, expecting Array +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed mean as a DoubleScalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Mean(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Sum values of a numeric array. +/// +/// \param[in] value datum to sum, expecting Array or ChunkedArray +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed sum as a Scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Sum(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Calculate the min / max of a numeric array +/// +/// This function returns both the min and max as a collection. 
The resulting +/// datum thus consists of two scalar datums: {Datum(min), Datum(max)} +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see MinMaxOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum containing a {min, max} collection +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result MinMax(const Datum& value, + const MinMaxOptions& options = MinMaxOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the min / max of a numeric array. +/// +/// This function returns both the min and max as a collection. The resulting +/// datum thus consists of two scalar datums: {Datum(min), Datum(max)} +/// +/// \param[in] array input array +/// \param[in] options see MinMaxOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum containing a {min, max} collection +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result MinMax(const Array& array, + const MinMaxOptions& options = MinMaxOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h new file mode 100644 index 00000000000..1c8c1f511d9 --- /dev/null +++ b/cpp/src/arrow/compute/cast.h @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/options.h" +#include "arrow/datum.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +// ---------------------------------------------------------------------- +// Convenience invocation APIs for a number of kernels + +/// \brief Cast from one array type to another +/// \param[in] value array to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] context the function execution context, optional +/// \return the resulting array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> Cast(const Array& value, std::shared_ptr to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* context = NULLPTR); + +/// \brief Cast from one value to another +/// \param[in] value datum to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] context the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Cast(const Datum& value, std::shared_ptr to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* context = NULLPTR); + +/// \brief Return true if a cast function is defined +ARROW_EXPORT +bool CanCast(const DataType& from_type, const DataType& to_type); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/compute_test.cc 
b/cpp/src/arrow/compute/compute_test.cc deleted file mode 100644 index cd33466a67a..00000000000 --- a/cpp/src/arrow/compute/compute_test.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include -#include -#include - -#include - -#include "arrow/array.h" -#include "arrow/buffer.h" -#include "arrow/memory_pool.h" -#include "arrow/status.h" -#include "arrow/table.h" -#include "arrow/testing/gtest_common.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/decimal.h" - -#include "arrow/compute/context.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/util_internal.h" -#include "arrow/compute/test_util.h" - -namespace arrow { -namespace compute { - -// ---------------------------------------------------------------------- -// Datum - -template -void CheckImplicitConstructor(enum Datum::type expected_kind) { - std::shared_ptr value; - Datum datum = value; - ASSERT_EQ(expected_kind, datum.kind()); -} - -TEST(TestDatum, ImplicitConstructors) { - CheckImplicitConstructor(Datum::SCALAR); - - CheckImplicitConstructor(Datum::ARRAY); - - // Instantiate from array subclass - 
CheckImplicitConstructor(Datum::ARRAY); - - CheckImplicitConstructor(Datum::CHUNKED_ARRAY); - CheckImplicitConstructor(Datum::RECORD_BATCH); - - CheckImplicitConstructor(Datum::TABLE); -} - -class TestInvokeBinaryKernel : public ComputeFixture, public TestBase {}; - -TEST_F(TestInvokeBinaryKernel, Exceptions) { - MockBinaryKernel kernel; - std::vector outputs; - std::shared_ptr
table; - std::vector values1 = {true, false, true}; - std::vector values2 = {false, true, false}; - - auto type = boolean(); - auto a1 = _MakeArray(type, values1, {}); - auto a2 = _MakeArray(type, values2, {}); - - // Left is not an array-like - ASSERT_RAISES(Invalid, detail::InvokeBinaryArrayKernel(&this->ctx_, &kernel, table, a2, - &outputs)); - // Right is not an array-like - ASSERT_RAISES(Invalid, detail::InvokeBinaryArrayKernel(&this->ctx_, &kernel, a1, table, - &outputs)); - // Different sized inputs - ASSERT_RAISES(Invalid, detail::InvokeBinaryArrayKernel(&this->ctx_, &kernel, a1, - a1->Slice(1), &outputs)); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/context.h b/cpp/src/arrow/compute/context.h deleted file mode 100644 index dde8b686fc3..00000000000 --- a/cpp/src/arrow/compute/context.h +++ /dev/null @@ -1,79 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#pragma once - -#include -#include - -#include "arrow/memory_pool.h" -#include "arrow/status.h" -#include "arrow/util/macros.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class Buffer; - -namespace internal { -class CpuInfo; -} // namespace internal - -namespace compute { - -#define ARROW_RETURN_IF_ERROR(ctx) \ - if (ARROW_PREDICT_FALSE(ctx->HasError())) { \ - Status s = ctx->status(); \ - ctx->ResetStatus(); \ - return s; \ - } - -/// \brief Container for variables and options used by function evaluation -class ARROW_EXPORT FunctionContext { - public: - explicit FunctionContext(MemoryPool* pool = default_memory_pool()); - MemoryPool* memory_pool() const; - - /// \brief Allocate buffer from the context's memory pool - Status Allocate(const int64_t nbytes, std::shared_ptr* out); - - /// \brief Indicate that an error has occurred, to be checked by a parent caller - /// \param[in] status a Status instance - /// - /// \note Will not overwrite a prior set Status, so we will have the first - /// error that occurred until FunctionContext::ResetStatus is called - void SetStatus(const Status& status); - - /// \brief Clear any error status - void ResetStatus(); - - /// \brief Return true if an error has occurred - bool HasError() const { return !status_.ok(); } - - /// \brief Return the current status of the context - const Status& status() const { return status_; } - - internal::CpuInfo* cpu_info() const { return cpu_info_; } - - private: - Status status_; - MemoryPool* pool_; - internal::CpuInfo* cpu_info_; -}; - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc new file mode 100644 index 00000000000..7683ef23403 --- /dev/null +++ b/cpp/src/arrow/compute/exec.cc @@ -0,0 +1,859 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/exec.h" + +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/registry.h" +#include "arrow/datum.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +#define CTX_RETURN_IF_ERROR(CTX) \ + do { \ + if (ARROW_PREDICT_FALSE((CTX)->HasError())) { \ + Status s = (CTX)->status(); \ + (CTX)->ResetStatus(); \ + return s; \ + } \ + } while (0) + +namespace { + +Result> AllocateDataBuffer(KernelContext* ctx, int64_t length, + int bit_width) { + if (bit_width == 1) { + return ctx->AllocateBitmap(length); + } else { + ARROW_CHECK_EQ(bit_width % 8, 0) + << "Only bit widths with multiple of 8 are currently supported"; + int64_t buffer_size = length * bit_width / 8; + return ctx->Allocate(buffer_size); + } + return Status::OK(); +} + +bool CanPreallocate(const DataType& type) { return is_fixed_width(type.id()); } + +Status GetValueDescriptors(const std::vector& args, + std::vector* 
descrs) { + for (const auto& arg : args) { + descrs->emplace_back(arg.descr()); + } + return Status::OK(); +} + +} // namespace + +namespace detail { + +ExecBatchIterator::ExecBatchIterator(std::vector args, int64_t length, + int64_t max_chunksize) + : args_(std::move(args)), + position_(0), + length_(length), + max_chunksize_(max_chunksize), + finished_(false) { + chunk_indexes_.resize(args_.size(), 0); + chunk_positions_.resize(args_.size(), 0); +} + +Result> ExecBatchIterator::Make( + std::vector args, int64_t max_chunksize) { + for (const auto& arg : args) { + if (!(arg.is_arraylike() || arg.is_scalar())) { + return Status::Invalid( + "ExecBatchIterator only works with Scalar, Array, and " + "ChunkedArray arguments"); + } + } + + // If the arguments are all scalars, then the length is 1 + int64_t length = 1; + + bool length_set = false; + for (size_t i = 0; i < args.size(); ++i) { + if (args[i].is_scalar()) { + continue; + } + if (!length_set) { + length = args[i].length(); + length_set = true; + } else { + if (args[i].length() != length) { + return Status::Invalid("Array arguments must all be the same length"); + } + } + } + + // No maximum was indicated + if (max_chunksize < 1) { + max_chunksize = length; + } + + return std::unique_ptr( + new ExecBatchIterator(std::move(args), length, max_chunksize)); +} + +bool ExecBatchIterator::Next(ExecBatch* batch) { + if (finished_) return false; + + // Determine how large the common contiguous "slice" of all the arguments is + int64_t iteration_size = std::min(length_ - position_, max_chunksize_); + + // If length_ is 0, then this loop will never execute + for (size_t i = 0; i < args_.size() && iteration_size > 0; ++i) { + // If the argument is not a chunked array, it's either a Scalar or Array, + // in which case it doesn't influence the size of this batch. 
Note that if + // the args are all scalars the batch length is 1 + if (args_[i].kind() != Datum::CHUNKED_ARRAY) { + continue; + } + const ChunkedArray& arg = *args_[i].chunked_array(); + std::shared_ptr current_chunk; + while (true) { + current_chunk = arg.chunk(chunk_indexes_[i]); + if (chunk_positions_[i] == current_chunk->length()) { + // Chunk is zero-length, or was exhausted in the previous iteration + chunk_positions_[i] = 0; + ++chunk_indexes_[i]; + continue; + } + break; + } + iteration_size = + std::min(current_chunk->length() - chunk_positions_[i], iteration_size); + } + + // Now, fill the batch + batch->values.resize(args_.size()); + batch->length = iteration_size; + for (size_t i = 0; i < args_.size(); ++i) { + if (args_[i].is_scalar()) { + batch->values[i] = args_[i].scalar(); + } else if (args_[i].is_array()) { + batch->values[i] = args_[i].array()->Slice(position_, iteration_size); + } else { + const ChunkedArray& carr = *args_[i].chunked_array(); + if (carr.num_chunks() > 0) { + const auto& chunk = carr.chunk(chunk_indexes_[i]); + batch->values[i] = chunk->data()->Slice(chunk_positions_[i], iteration_size); + } else { + // Degenerate case of a ChunkedArray with zero chunks + DCHECK_EQ(0, length_); + batch->values[i] = ArrayData::Make(carr.type(), 0); + } + chunk_positions_[i] += iteration_size; + } + } + position_ += iteration_size; + DCHECK_LE(position_, length_); + if (position_ == length_) { + finished_ = true; + } + return true; +} + +// Null propagation implementation that deals both with preallocated bitmaps +// and maybe-to-be allocated bitmaps +// +// If the bitmap is preallocated, it MUST be populated (since it might be a +// view of a much larger bitmap). If it isn't preallocated, then we have +// more flexibility. 
+// +// * If the batch has no nulls, then we do nothing +// * If only a single array has nulls, and its offset is a multiple of 8, +// then we can zero-copy the bitmap into the output +// * Otherwise, we allocate the bitmap and populate it +class NullPropagator { + public: + NullPropagator(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) + : ctx_(ctx), batch_(batch), output_(output) { + // At this point, the values in batch_.values must have been validated to + // all be value-like + for (const Datum& val : batch_.values) { + if (val.kind() == Datum::ARRAY) { + // Do not count the bits if they haven't been counted already + const int64_t known_null_count = val.array()->null_count.load(); + if (known_null_count == kUnknownNullCount || known_null_count > 0) { + values_with_nulls_.push_back(&val); + } + } else if (!val.scalar()->is_valid) { + values_with_nulls_.push_back(&val); + } + } + + if (output->buffers[0] != nullptr) { + bitmap_preallocated_ = true; + SetBitmap(output_->buffers[0].get()); + } + } + + void SetBitmap(Buffer* bitmap) { bitmap_ = bitmap->mutable_data(); } + + Status EnsureAllocated() { + if (bitmap_preallocated_) { + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE(output_->buffers[0], ctx_->AllocateBitmap(output_->length)); + SetBitmap(output_->buffers[0].get()); + return Status::OK(); + } + + Result ShortCircuitIfAllNull() { + // An all-null value (scalar null or all-null array) gives us a short + // circuit opportunity + bool is_all_null = false; + std::shared_ptr all_null_bitmap; + + // Walk all the values with nulls instead of breaking on the first in case + // we find a bitmap that can be reused in the non-preallocated case + for (const Datum* value : values_with_nulls_) { + if (value->type()->id() == Type::NA) { + // No bitmap + is_all_null = true; + } else if (value->kind() == Datum::ARRAY) { + const ArrayData& arr = *value->array(); + if (arr.null_count.load() == arr.length) { + // Pluck the all null bitmap so we can set it 
in the output if it was + // not pre-allocated + all_null_bitmap = arr.buffers[0]; + is_all_null = true; + } + } else { + // Scalar + is_all_null = true; + } + } + if (!is_all_null) { + return false; + } + + // OK, the output should be all null + output_->null_count = output_->length; + + if (!bitmap_preallocated_ && all_null_bitmap) { + // If we did not pre-allocate memory, and we observed an all-null bitmap, + // then we can zero-copy it into the output + output_->buffers[0] = std::move(all_null_bitmap); + } else { + RETURN_NOT_OK(EnsureAllocated()); + BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, false); + } + return true; + } + + Status PropagateSingle() { + // One array + const ArrayData& arr = *values_with_nulls_[0]->array(); + const std::shared_ptr& arr_bitmap = arr.buffers[0]; + + // Reuse the null count if it's known + output_->null_count = arr.null_count.load(); + + if (bitmap_preallocated_) { + internal::CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, + output_->offset); + } else { + // Two cases when memory was not pre-allocated: + // + // * Offset is zero: we reuse the bitmap as is + // * Offset is nonzero but a multiple of 8: we can slice the bitmap + // * Offset is not a multiple of 8: we must allocate and use CopyBitmap + // + // Keep in mind that output_->offset is not permitted to be nonzero when + // the bitmap is not preallocated, and that precondition is asserted + // higher in the call stack. + if (arr.offset == 0) { + output_->buffers[0] = arr_bitmap; + } else if (arr.offset % 8 == 0) { + output_->buffers[0] = + SliceBuffer(arr_bitmap, arr.offset / 8, BitUtil::BytesForBits(arr.length)); + } else { + RETURN_NOT_OK(EnsureAllocated()); + internal::CopyBitmap(arr_bitmap->data(), arr.offset, arr.length, bitmap_, + /*dst_offset=*/0); + } + } + return Status::OK(); + } + + Status PropagateMultiple() { + // More than one array. 
We use BitmapAnd to intersect their bitmaps + + // Do not compute the intersection null count until it's needed + RETURN_NOT_OK(EnsureAllocated()); + + auto Accumulate = [&](const ArrayData& left, const ArrayData& right) { + internal::BitmapAnd(left.buffers[0]->data(), left.offset, right.buffers[0]->data(), + right.offset, output_->length, output_->offset, + output_->buffers[0]->mutable_data()); + }; + + DCHECK_GT(values_with_nulls_.size(), 1); + + // Seed the output bitmap with the & of the first two bitmaps + Accumulate(*values_with_nulls_[0]->array(), *values_with_nulls_[1]->array()); + + // Accumulate the rest + for (size_t i = 2; i < values_with_nulls_.size(); ++i) { + Accumulate(*output_, *values_with_nulls_[i]->array()); + } + return Status::OK(); + } + + Status Execute() { + bool finished = false; + ARROW_ASSIGN_OR_RAISE(finished, ShortCircuitIfAllNull()); + if (finished) { + return Status::OK(); + } + + // At this point, by construction we know that all of the values in + // values_with_nulls_ are arrays that are not all null. So there are a + // few cases: + // + // * No arrays. 
This is a no-op w/o preallocation but when the bitmap is + // pre-allocated we have to fill it with 1's + // * One array, whose bitmap can be zero-copied (w/o preallocation, and + // when no byte is split) or copied (split byte or w/ preallocation) + // * More than one array, we must compute the intersection of all the + // bitmaps + // + // BUT, if the output offset is nonzero for some reason, we copy into the + // output unconditionally + + output_->null_count = kUnknownNullCount; + + if (values_with_nulls_.size() == 0) { + // No arrays with nulls case + output_->null_count = 0; + if (bitmap_preallocated_) { + BitUtil::SetBitsTo(bitmap_, output_->offset, output_->length, true); + } + return Status::OK(); + } else if (values_with_nulls_.size() == 1) { + return PropagateSingle(); + } else { + return PropagateMultiple(); + } + } + + private: + KernelContext* ctx_; + const ExecBatch& batch_; + std::vector values_with_nulls_; + ArrayData* output_; + uint8_t* bitmap_; + bool bitmap_preallocated_ = false; +}; + +Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* output) { + DCHECK_NE(nullptr, output); + DCHECK_GT(output->buffers.size(), 0); + + if (output->type->id() == Type::NA) { + // Null output type is a no-op (rare when this would happen but we at least + // will test for it) + return Status::OK(); + } + + // This function is ONLY able to write into output with non-zero offset + // when the bitmap is preallocated. 
This could be a DCHECK but returning + // error Status for now for emphasis + if (output->offset != 0 && output->buffers[0] == nullptr) { + return Status::Invalid( + "Can only propagate nulls into pre-allocated memory " + "when the output offset is non-zero"); + } + NullPropagator propagator(ctx, batch, output); + return propagator.Execute(); +} + +Status ExecListener::OnResult(Datum) { return Status::NotImplemented("OnResult"); } + +class DatumAccumulator : public ExecListener { + public: + DatumAccumulator() {} + + Status OnResult(Datum value) override { + values_.emplace_back(value); + return Status::OK(); + } + + std::vector values() const { return values_; } + + private: + std::vector values_; +}; + +template +class FunctionExecutorImpl : public FunctionExecutor { + public: + FunctionExecutorImpl(ExecContext* exec_ctx, const FunctionType* func, + const FunctionOptions* options) + : exec_ctx_(exec_ctx), kernel_ctx_(exec_ctx), func_(func), options_(options) {} + + protected: + using KernelType = typename FunctionType::KernelType; + + void Reset() {} + + Status InitState() { + // Some kernels require initialization of an opaque state object + if (kernel_->init) { + state_ = kernel_->init(&kernel_ctx_, *kernel_, options_); + CTX_RETURN_IF_ERROR(&kernel_ctx_); + kernel_ctx_.SetState(state_.get()); + } + return Status::OK(); + } + + Status BindArgs(const std::vector& args) { + std::vector arg_descrs; + RETURN_NOT_OK(GetValueDescriptors(args, &arg_descrs)); + ARROW_ASSIGN_OR_RAISE(kernel_, func_->DispatchExact(arg_descrs)); + + // Resolve the output descriptor for this kernel + ARROW_ASSIGN_OR_RAISE(output_descr_, + kernel_->signature->out_type().Resolve(arg_descrs)); + + ARROW_ASSIGN_OR_RAISE(batch_iterator_, + ExecBatchIterator::Make(args, exec_ctx_->exec_chunksize())); + + return Status::OK(); + } + + ValueDescr output_descr() const override { return output_descr_; } + + ExecContext* exec_ctx_; + KernelContext kernel_ctx_; + const FunctionType* func_; + const 
KernelType* kernel_; + std::unique_ptr batch_iterator_; + std::unique_ptr state_; + ValueDescr output_descr_; + const FunctionOptions* options_; +}; + +// Executor for SCALAR and VECTOR functions +template +class ArrayExecutor : public FunctionExecutorImpl { + public: + using BASE = FunctionExecutorImpl; + using BASE::BASE; + + Status ExecuteBatch(const ExecBatch& batch, ExecListener* listener) { + Datum out; + RETURN_NOT_OK(PrepareNextOutput(batch, &out)); + + if (kernel_->null_handling == NullHandling::INTERSECTION && + output_descr_.shape == ValueDescr::ARRAY) { + RETURN_NOT_OK(PropagateNulls(&kernel_ctx_, batch, out.mutable_array())); + } + + kernel_->exec(&kernel_ctx_, batch, &out); + CTX_RETURN_IF_ERROR(&kernel_ctx_); + if (!preallocate_contiguous_) { + // If we are producing chunked output rather than one big array, then + // emit each chunk as soon as it's available + RETURN_NOT_OK(listener->OnResult(std::move(out))); + } + return Status::OK(); + } + + Status PrepareExecute(const std::vector& args) { + this->Reset(); + RETURN_NOT_OK(this->BindArgs(args)); + RETURN_NOT_OK(this->InitState()); + output_num_buffers_ = static_cast(output_descr_.type->layout().buffers.size()); + + // If the executor is configured to produce a single large Array output for + // kernels supporting preallocation, then we do so up front and then + // iterate over slices of that large array. 
Otherwise, we preallocate prior + // to processing each batch emitted from the ExecBatchIterator + if (output_descr_.shape == ValueDescr::ARRAY) { + RETURN_NOT_OK(SetupPreallocation(batch_iterator_->length())); + } + return Status::OK(); + } + + Status Execute(const std::vector& args, ExecListener* listener) override { + RETURN_NOT_OK(PrepareExecute(args)); + ExecBatch batch; + while (batch_iterator_->Next(&batch)) { + RETURN_NOT_OK(ExecuteBatch(batch, listener)); + } + if (preallocate_contiguous_) { + // If we preallocated one big chunk, since the kernel execution is + // completed, we can now emit it + RETURN_NOT_OK(listener->OnResult(std::move(preallocated_))); + } + return Status::OK(); + } + + protected: + // We must accommodate two different modes of execution for preallocated + // execution + // + // * A single large ("contiguous") allocation that we populate with results + // on a chunkwise basis according to the ExecBatchIterator. This permits + // parallelization even if the objective is to obtain a single Array or + // ChunkedArray at the end + // * A standalone buffer preallocation for each chunk emitted from the + // ExecBatchIterator + // + // When data buffer preallocation is not possible (e.g. with BINARY / STRING + // outputs), then contiguous results are only possible if the input is + // contiguous. + + Status PrepareNextOutput(const ExecBatch& batch, Datum* out) { + if (output_descr_.shape == ValueDescr::ARRAY) { + if (preallocate_contiguous_) { + // The output is already fully preallocated + const int64_t batch_start_position = batch_iterator_->position() - batch.length; + + if (batch.length < batch_iterator_->length()) { + // If this is a partial execution, then we write into a slice of + // preallocated_ + // + // XXX: ArrayData::Slice not returning std::shared_ptr is + // a nuisance + out->value = std::make_shared( + preallocated_->Slice(batch_start_position, batch.length)); + } else { + // Otherwise write directly into preallocated_. 
The main difference + // computationally (versus the Slice approach) is that the null_count + // may not need to be recomputed in the result + out->value = preallocated_; + } + } else { + // We preallocate (maybe) only for the output of processing the current + // batch + ARROW_ASSIGN_OR_RAISE(out->value, PrepareOutput(batch.length)); + } + } + // XXX: Scalar outputs are the responsibility of the kernel? + return Status::OK(); + } + + Result> PrepareOutput(int64_t length) { + auto out = std::make_shared(output_descr_.type, length); + out->buffers.resize(output_num_buffers_); + + const auto& fw_type = checked_cast(*out->type); + if (validity_preallocated_) { + ARROW_ASSIGN_OR_RAISE(out->buffers[0], kernel_ctx_.AllocateBitmap(length)); + } + if (data_preallocated_) { + ARROW_ASSIGN_OR_RAISE( + out->buffers[1], AllocateDataBuffer(&kernel_ctx_, length, fw_type.bit_width())); + } + return out; + } + + Status SetupPreallocation(int64_t total_length) { + // Decide if we need to preallocate memory for this kernel + data_preallocated_ = ((kernel_->mem_allocation == MemAllocation::PREALLOCATE) && + CanPreallocate(*output_descr_.type)); + + validity_preallocated_ = + (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && + kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); + + // Contiguous preallocation only possible if both the VALIDITY and DATA can + // be preallocated. Otherwise, we must go chunk-by-chunk. Note that when + // the DATA cannot be preallocated, the VALIDITY may still be preallocated + // depending on the NullHandling of the kernel + // + // Some kernels are unable to write into sliced outputs, so we respect the + // kernel's attributes + preallocate_contiguous_ = + (exec_ctx_->preallocate_contiguous() && kernel_->can_write_into_slices && + data_preallocated_ && validity_preallocated_); + if (preallocate_contiguous_) { + // TODO: Are there contiguous preallocation scenarios that are NOT + // primitive (2-buffer)? 
+ DCHECK_EQ(2, output_num_buffers_); + ARROW_ASSIGN_OR_RAISE(preallocated_, PrepareOutput(total_length)); + } + return Status::OK(); + } + + // Lift protected members so we don't have to use this-> + using BASE::batch_iterator_; + using BASE::exec_ctx_; + using BASE::func_; + using BASE::kernel_; + using BASE::kernel_ctx_; + using BASE::options_; + using BASE::output_descr_; + using BASE::state_; + + int output_num_buffers_; + + // If true, then the kernel writes into a preallocated data buffer + bool data_preallocated_ = false; + + // If true, then memory is preallocated for the validity bitmap with the same + // strategy as the data buffer(s). + bool validity_preallocated_ = false; + + // If true, and the kernel and output type supports preallocation (for both + // the validity and data buffers), then we allocate one big array and then + // iterate through it while executing the kernel in chunks + bool preallocate_contiguous_ = false; + + // For storing a contiguous preallocation per above. 
Unused otherwise + std::shared_ptr preallocated_; +}; + +class ScalarExecutor : public ArrayExecutor { + public: + using FunctionType = ScalarFunction; + static constexpr Function::Kind function_kind = Function::SCALAR; + using BASE = ArrayExecutor; + using BASE::BASE; +}; + +class VectorExecutor : public ArrayExecutor { + public: + using FunctionType = VectorFunction; + static constexpr Function::Kind function_kind = Function::VECTOR; + using BASE = ArrayExecutor; + using BASE::BASE; +}; + +class ScalarAggExecutor : public FunctionExecutorImpl { + public: + using FunctionType = ScalarAggregateFunction; + static constexpr Function::Kind function_kind = Function::SCALAR_AGGREGATE; + using BASE = FunctionExecutorImpl; + + Status Consume(const ExecBatch& batch) { + auto batch_state = kernel_->init(&kernel_ctx_, *kernel_, options_); + KernelContext batch_ctx(exec_ctx_); + batch_ctx.SetState(batch_state.get()); + + kernel_->consume(&batch_ctx, batch); + CTX_RETURN_IF_ERROR(&batch_ctx); + + kernel_->merge(&kernel_ctx_, *batch_state, state_.get()); + CTX_RETURN_IF_ERROR(&kernel_ctx_); + return Status::OK(); + } + + Status Execute(const std::vector& args, ExecListener* listener) override { + RETURN_NOT_OK(BindArgs(args)); + + // This is the global/total state for the aggregation. 
Batches are + // aggregated independently and then merged into the state + RETURN_NOT_OK(InitState()); + + ExecBatch batch; + while (batch_iterator_->Next(&batch)) { + // TODO: implement parallelism + if (batch.length > 0) { + RETURN_NOT_OK(Consume(batch)); + } + } + + Datum out; + kernel_->finalize(&kernel_ctx_, &out); + CTX_RETURN_IF_ERROR(&kernel_ctx_); + RETURN_NOT_OK(listener->OnResult(std::move(out))); + return Status::OK(); + } + + private: + using BASE::BASE; +}; + +template +Result> MakeExecutor(ExecContext* ctx, + const Function* func, + const FunctionOptions* options) { + DCHECK_EQ(ExecutorType::function_kind, func->kind()); + auto typed_func = checked_cast(func); + return std::unique_ptr(new ExecutorType(ctx, typed_func, options)); +} + +Result> FunctionExecutor::Make( + ExecContext* ctx, const Function* func, const FunctionOptions* options) { + switch (func->kind()) { + case Function::SCALAR: + return MakeExecutor(ctx, func, options); + case Function::VECTOR: + return MakeExecutor(ctx, func, options); + case Function::SCALAR_AGGREGATE: + return MakeExecutor(ctx, func, options); + } +} + +Status CheckAllValues(const std::vector& values) { + for (const auto& value : values) { + if (!value.is_value()) { + return Status::Invalid("Datum contained non-scalar/array type"); + } + } + return Status::OK(); +} + +Status ExecuteFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, const FunctionOptions* options, + ValueDescr* out_descr, ExecListener* listener) { + if (ctx == nullptr) { + ExecContext default_ctx; + return ExecuteFunction(&default_ctx, func_name, args, options, out_descr, listener); + } + + // type-check Datum arguments here. 
Really we'd like to avoid this as much as + // possible + RETURN_NOT_OK(CheckAllValues(args)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr func, + ctx->func_registry()->GetFunction(func_name)); + ARROW_ASSIGN_OR_RAISE(auto executor, FunctionExecutor::Make(ctx, func.get(), options)); + RETURN_NOT_OK(executor->Execute(args, listener)); + *out_descr = executor->output_descr(); + return Status::OK(); +} + +} // namespace detail + +ExecContext::ExecContext(MemoryPool* pool, FunctionRegistry* func_registry) + : pool_(pool) { + this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry; +} + +internal::CpuInfo* ExecContext::cpu_info() const { + return internal::CpuInfo::GetInstance(); +} + +// ---------------------------------------------------------------------- +// SelectionVector + +SelectionVector::SelectionVector(std::shared_ptr data) + : data_(std::move(data)) { + DCHECK_EQ(Type::INT32, data_->type->id()); + DCHECK_EQ(0, data_->GetNullCount()); + indices_ = data_->GetValues(1); +} + +Result> SelectionVector::FromMask(const Array& arr) { + return Status::NotImplemented("FromMask"); +} + +namespace { + +std::shared_ptr ToChunkedArray(const std::vector& values, + const std::shared_ptr& type) { + std::vector> arrays; + for (const auto& val : values) { + arrays.emplace_back(val.make_array()); + } + return std::make_shared(arrays, type); +} + +bool HaveChunkedArray(const std::vector& values) { + for (const auto& value : values) { + if (value.kind() == Datum::CHUNKED_ARRAY) { + return true; + } + } + return false; +} + +Datum WrapArrayResults(const std::vector& input_args, + const std::vector& results, + const ValueDescr& output_descr) { + DCHECK_GT(results.size(), 0); + if (output_descr.shape == ValueDescr::SCALAR) { + if (results.size() == 1) { + // Return as SCALAR + return results[0]; + } else { + // Return as COLLECTION + return results; + } + } else { + // If execution yielded multiple chunks (because large arrays were split + // based on 
the ExecContext parameters, then the result is a ChunkedArray + if (HaveChunkedArray(input_args) || results.size() > 1) { + return ToChunkedArray(results, output_descr.type); + } else { + // Results have just one element + return results[0]; + } + } +} + +} // namespace + +Result ExecScalarFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, + const FunctionOptions* options) { + auto listener = std::make_shared(); + ValueDescr out_descr; + RETURN_NOT_OK( + detail::ExecuteFunction(ctx, func_name, args, options, &out_descr, listener.get())); + return WrapArrayResults(args, listener->values(), out_descr); +} + +Result ExecVectorFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, + const FunctionOptions* options) { + auto listener = std::make_shared(); + ValueDescr out_descr; + RETURN_NOT_OK( + detail::ExecuteFunction(ctx, func_name, args, options, &out_descr, listener.get())); + return WrapArrayResults(args, listener->values(), out_descr); +} + +Result ExecScalarAggregateFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, + const FunctionOptions* options) { + auto listener = std::make_shared(); + ValueDescr unused; + RETURN_NOT_OK( + detail::ExecuteFunction(ctx, func_name, args, options, &unused, listener.get())); + DCHECK_EQ(1, listener->values().size()); + return listener->values()[0]; +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h new file mode 100644 index 00000000000..b473838e281 --- /dev/null +++ b/cpp/src/arrow/compute/exec.h @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/datum.h" +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +namespace internal { + +class CpuInfo; + +} // namespace internal + +namespace compute { + +struct FunctionOptions; +class FunctionRegistry; + +// It seems like 64K might be a good default chunksize to use for execution +// based on the experience of other query processing systems, so using this for +// now. +static constexpr int64_t kDefaultExecChunksize = UINT16_MAX; + +/// \brief Context for expression-global variables and options used by +/// function evaluation +class ARROW_EXPORT ExecContext { + public: + // If no function registry passed, the default is used + explicit ExecContext(MemoryPool* pool = default_memory_pool(), + FunctionRegistry* func_registry = NULLPTR); + + MemoryPool* memory_pool() const { return pool_; } + + internal::CpuInfo* cpu_info() const; + + FunctionRegistry* func_registry() const { return func_registry_; } + + // \brief Set maximum length unit of work for kernel execution. Larger inputs + // will be split into smaller chunks, and, if desired, processed in + // parallel. 
Set to -1 for no limit + void set_exec_chunksize(int64_t chunksize) { exec_chunksize_ = chunksize; } + + // \brief Maximum length unit of work for kernel execution. + int64_t exec_chunksize() const { return exec_chunksize_; } + + /// \brief Set whether to use multiple threads for function execution + void set_use_threads(bool use_threads = true) { use_threads_ = use_threads; } + + /// \brief If true, then utilize multiple threads where relevant for function + /// execution + bool use_threads() const { return use_threads_; } + + // Set the preallocation strategy for kernel execution as it relates to + // chunked execution. For chunked execution, whether via ChunkedArray inputs + // or splitting larger Array arguments into smaller pieces, contiguous + // allocation (if permitted by the kernel) will allocate one large array to + // write output into yielding it to the caller at the end. If this option is + // set to off, then preallocations will be performed independently for each + // chunk of execution + // + // TODO: At some point we might want the limit the size of contiguous + // preallocations (for example, merging small ChunkedArray chunks until + // reaching some desired size) + void set_preallocate_contiguous(bool preallocate = true) { + preallocate_contiguous_ = preallocate; + } + + bool preallocate_contiguous() const { return preallocate_contiguous_; } + + private: + MemoryPool* pool_; + FunctionRegistry* func_registry_; + int64_t exec_chunksize_ = -1; + bool preallocate_contiguous_ = true; + bool use_threads_ = true; +}; + +// TODO: Consider standardizing on uint16 selection vectors and only use them +// when we can ensure that each value is 64K length or smaller + +/// \brief Container for a int32 selection +class ARROW_EXPORT SelectionVector { + public: + explicit SelectionVector(std::shared_ptr data); + + explicit SelectionVector(const Array& arr) : SelectionVector(arr.data()) {} + + /// \brief Create SelectionVector from boolean mask + static Result> 
FromMask(const Array& arr); + + int32_t index(int i) const { return indices_[i]; } + const int32_t* indices() const { return indices_; } + int32_t length() const { return static_cast(data_->length); } + + private: + std::shared_ptr data_; + const int32_t* indices_; +}; + +struct ExecBatch { + ExecBatch() {} + ExecBatch(std::vector values, int64_t length) + : values(std::move(values)), length(length) {} + + std::vector values; + std::shared_ptr selection_vector; + int64_t length; + const Datum& operator[](int i) const { return values[i]; } + + int num_values() const { return static_cast(values.size()); } + + std::vector GetDescriptors() const { + std::vector result; + for (const auto& value : this->values) { + result.emplace_back(value.descr()); + } + return result; + } +}; + +/// \brief Convenience method for invoking a scalar (elementwise) array +/// function, including handling iteration on ChunkedArray inputs +ARROW_EXPORT +Result ExecScalarFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, + const FunctionOptions* options = NULLPTR); + +/// \brief Convenience method for invoking a vector array function, including +/// handling iteration on ChunkedArray inputs. Compared with a scalar function, +/// vector functions may require post-processing of chunked outputs if the +/// results are dependent on the whole data passed (e.g. 
with hash table +/// functions) +ARROW_EXPORT +Result ExecVectorFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, + const FunctionOptions* options = NULLPTR); + +/// \brief Convenience method for invoking a scalar aggregate function, +/// including handling iteration on ChunkedArray inputs +ARROW_EXPORT +Result ExecScalarAggregateFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, + const FunctionOptions* options = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h new file mode 100644 index 00000000000..1c61541d557 --- /dev/null +++ b/cpp/src/arrow/compute/exec_internal.h @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +class Function; + +// \brief Make a copy of the buffers into a destination array without carrying +// the type. 
+static inline void ZeroCopyData(const ArrayData& input, ArrayData* output) { + output->length = input.length; + output->SetNullCount(input.null_count); + output->buffers = input.buffers; + output->offset = input.offset; + output->child_data = input.child_data; +} + +namespace detail { + +/// \brief Break std::vector into a sequence of ExecBatch for kernel +/// execution +class ARROW_EXPORT ExecBatchIterator { + public: + /// \brief Construct iterator and do basic argument validation + /// + /// \param[in] args the Datum argument, must be all array-like or scalar + /// \param[in] max_chunksize the maximum length of each ExecBatch. Depending + /// on the chunk layout of ChunkedArray. Default of -1 means no maximum, so + /// as greedy as possible + static Result> Make(std::vector args, + int64_t max_chunksize = -1); + + /// \brief Compute the next batch. Always returns at least one batch. Return + /// false if the iterator is exhausted + bool Next(ExecBatch* batch); + + int64_t length() const { return length_; } + + int64_t position() const { return position_; } + + int64_t max_chunksize() const { return max_chunksize_; } + + private: + ExecBatchIterator(std::vector args, int64_t length, int64_t max_chunksize); + + std::vector args_; + std::vector chunk_indexes_; + std::vector chunk_positions_; + int64_t position_; + int64_t length_; + int64_t max_chunksize_; + bool finished_; +}; + +// "Push" / listener API like IPC reader so that consumers can receive +// processed chunks as soon as they're available. 
+ +class ARROW_EXPORT ExecListener { + public: + virtual ~ExecListener() = default; + + virtual Status OnResult(Datum value); +}; + +class ARROW_EXPORT FunctionExecutor { + public: + virtual ~FunctionExecutor() = default; + + /// XXX: Better configurability for listener + /// Not thread-safe + virtual Status Execute(const std::vector& args, ExecListener* listener) = 0; + + virtual ValueDescr output_descr() const = 0; + + static Result> Make(ExecContext* ctx, + const Function* func, + const FunctionOptions* options); +}; + +ARROW_EXPORT +Status ExecuteFunction(ExecContext* ctx, const std::string& func_name, + const std::vector& args, const FunctionOptions* options, + ValueDescr* out_descr, ExecListener* listener); + +/// \brief Populate validity bitmap with the intersection of the nullity of the +/// arguments. If a preallocated bitmap is not provided, then one will be +/// allocated if needed (in some cases a bitmap can be zero-copied from the +/// arguments). If any Scalar value is null, then the entire validity bitmap +/// will be set to null. +/// +/// \param[in] ctx kernel execution context, for memory allocation etc. +/// \param[in] batch the data batch +/// \param[in] out the output ArrayData, must not be null +ARROW_EXPORT +Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out); + +} // namespace detail +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc new file mode 100644 index 00000000000..ceee12785fd --- /dev/null +++ b/cpp/src/arrow/compute/exec_test.cc @@ -0,0 +1,840 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include + +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/options.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/test_util.h" +#include "arrow/memory_pool.h" +#include "arrow/pretty_print.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/cpu_info.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace detail { + +TEST(ExecContext, BasicWorkings) { + { + ExecContext ctx; + ASSERT_EQ(GetFunctionRegistry(), ctx.func_registry()); + ASSERT_EQ(default_memory_pool(), ctx.memory_pool()); + + // No default chunksize right now + ASSERT_EQ(-1, ctx.exec_chunksize()); + + ASSERT_TRUE(ctx.use_threads()); + ASSERT_EQ(internal::CpuInfo::GetInstance(), ctx.cpu_info()); + } + + // Now, let's customize all the things + LoggingMemoryPool my_pool(default_memory_pool()); + std::unique_ptr custom_reg = FunctionRegistry::Make(); + ExecContext ctx(&my_pool, custom_reg.get()); + + ASSERT_EQ(custom_reg.get(), ctx.func_registry()); + ASSERT_EQ(&my_pool, ctx.memory_pool()); + + ctx.set_exec_chunksize(1 
<< 20); + ASSERT_EQ(1 << 20, ctx.exec_chunksize()); + + ctx.set_use_threads(false); + ASSERT_FALSE(ctx.use_threads()); +} + +TEST(SelectionVector, Basics) { + auto indices = ArrayFromJSON(int32(), "[0, 3]"); + auto sel_vector = std::make_shared(*indices); + + ASSERT_EQ(indices->length(), sel_vector->length()); + ASSERT_EQ(3, sel_vector->index(1)); + ASSERT_EQ(3, sel_vector->indices()[1]); +} + +void AssertValidityZeroExtraBits(const ArrayData& arr) { + const Buffer& buf = *arr.buffers[0]; + + const int64_t bit_extent = ((arr.offset + arr.length + 7) / 8) * 8; + for (int64_t i = arr.offset + arr.length; i < bit_extent; ++i) { + EXPECT_FALSE(BitUtil::GetBit(buf.data(), i)) << i; + } +} + +class TestComputeInternals : public ::testing::Test { + public: + void SetUp() { + registry_ = FunctionRegistry::Make(); + rng_.reset(new random::RandomArrayGenerator(/*seed=*/0)); + ResetContexts(); + } + + void ResetContexts() { + exec_ctx_.reset(new ExecContext(default_memory_pool(), registry_.get())); + ctx_.reset(new KernelContext(exec_ctx_.get())); + } + + std::shared_ptr GetUInt8Array(int64_t size, double null_probability = 0.1) { + return rng_->UInt8(size, /*min=*/0, /*max=*/100, null_probability); + } + + std::shared_ptr GetInt32Array(int64_t size, double null_probability = 0.1) { + return rng_->Int32(size, /*min=*/0, /*max=*/1000, null_probability); + } + + std::shared_ptr GetFloat64Array(int64_t size, double null_probability = 0.1) { + return rng_->Float64(size, /*min=*/0, /*max=*/1000, null_probability); + } + + std::shared_ptr GetInt32Chunked(const std::vector& sizes) { + std::vector> chunks; + for (auto size : sizes) { + chunks.push_back(GetInt32Array(size)); + } + return std::make_shared(std::move(chunks)); + } + + protected: + std::unique_ptr exec_ctx_; + std::unique_ptr ctx_; + std::unique_ptr registry_; + std::unique_ptr rng_; +}; + +class TestPropagateNulls : public TestComputeInternals {}; + +TEST_F(TestPropagateNulls, UnknownNullCountWithNullsZeroCopies) { + 
const int64_t length = 16; + + constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0}; + auto nulls = std::make_shared(validity_bitmap, 8); + + ArrayData output(boolean(), length, {nullptr, nullptr}); + ArrayData input(boolean(), length, {nulls, nullptr}, kUnknownNullCount); + + ExecBatch batch({input}, length); + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + ASSERT_EQ(nulls.get(), output.buffers[0].get()); + ASSERT_EQ(kUnknownNullCount, output.null_count); + ASSERT_EQ(9, output.GetNullCount()); +} + +TEST_F(TestPropagateNulls, UnknownNullCountWithoutNulls) { + const int64_t length = 16; + constexpr uint8_t validity_bitmap[8] = {255, 255, 0, 0, 0, 0, 0, 0}; + auto nulls = std::make_shared(validity_bitmap, 8); + + ArrayData output(boolean(), length, {nullptr, nullptr}); + ArrayData input(boolean(), length, {nulls, nullptr}, kUnknownNullCount); + + ExecBatch batch({input}, length); + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + EXPECT_EQ(-1, output.null_count); + EXPECT_EQ(nulls.get(), output.buffers[0].get()); +} + +TEST_F(TestPropagateNulls, SetAllNulls) { + const int64_t length = 16; + + auto CheckSetAllNull = [&](std::vector values, bool preallocate) { + // Make fresh bitmap with all 1's + uint8_t bitmap_data[2] = {255, 255}; + auto preallocated_mem = std::make_shared(bitmap_data, 2); + + std::vector> buffers(2); + if (preallocate) { + buffers[0] = preallocated_mem; + } + + ArrayData output(boolean(), length, buffers); + + ExecBatch batch(values, length); + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + + if (preallocate) { + // Ensure that buffer object the same when we pass in preallocated memory + ASSERT_EQ(preallocated_mem.get(), output.buffers[0].get()); + } + ASSERT_NE(nullptr, output.buffers[0]); + uint8_t expected[2] = {0, 0}; + const Buffer& out_buf = *output.buffers[0]; + ASSERT_EQ(0, std::memcmp(out_buf.data(), expected, out_buf.size())); + }; + + // There is a null scalar + std::shared_ptr i32_val = 
std::make_shared(3); + std::vector vals = {i32_val, MakeNullScalar(boolean())}; + CheckSetAllNull(vals, true); + CheckSetAllNull(vals, false); + + const double true_prob = 0.5; + + vals[0] = rng_->Boolean(length, true_prob); + CheckSetAllNull(vals, true); + CheckSetAllNull(vals, false); + + auto arr_all_nulls = rng_->Boolean(length, true_prob, /*null_probability=*/1); + + // One value is all null + vals = {rng_->Boolean(length, true_prob, /*null_probability=*/0.5), arr_all_nulls}; + CheckSetAllNull(vals, true); + CheckSetAllNull(vals, false); + + // A value is NullType + std::shared_ptr null_arr = std::make_shared(length); + vals = {rng_->Boolean(length, true_prob), null_arr}; + CheckSetAllNull(vals, true); + CheckSetAllNull(vals, false); + + // Other nitty-gritty scenarios + { + // An all-null bitmap is zero-copied over, even though there is a + // null-scalar earlier in the batch + ArrayData output(boolean(), length, {nullptr, nullptr}); + ExecBatch batch({MakeNullScalar(boolean()), arr_all_nulls}, length); + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + ASSERT_EQ(arr_all_nulls->data()->buffers[0].get(), output.buffers[0].get()); + } +} + +TEST_F(TestPropagateNulls, SingleValueWithNulls) { + // Input offset is non-zero (0 mod 8 and nonzero mod 8 cases) + const int64_t length = 100; + auto arr = rng_->Boolean(length, 0.5, /*null_probability=*/0.5); + + auto CheckSliced = [&](int64_t offset, bool preallocate = false, + int64_t out_offset = 0) { + // Unaligned bitmap, zero copy not possible + auto sliced = arr->Slice(offset); + std::vector vals = {sliced}; + + ArrayData output(boolean(), vals[0].length(), {nullptr, nullptr}); + output.offset = out_offset; + + ExecBatch batch(vals, vals[0].length()); + + std::shared_ptr preallocated_bitmap; + if (preallocate) { + ASSERT_OK_AND_ASSIGN( + preallocated_bitmap, + AllocateBuffer(BitUtil::BytesForBits(sliced->length() + out_offset))); + std::memset(preallocated_bitmap->mutable_data(), 0, 
preallocated_bitmap->size()); + output.buffers[0] = preallocated_bitmap; + } else { + ASSERT_EQ(0, output.offset); + } + + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + + if (!preallocate) { + const Buffer* parent_buf = arr->data()->buffers[0].get(); + if (offset == 0) { + // Validity bitmap same, no slice + ASSERT_EQ(parent_buf, output.buffers[0].get()); + } else if (offset % 8 == 0) { + // Validity bitmap sliced + ASSERT_NE(parent_buf, output.buffers[0].get()); + ASSERT_EQ(parent_buf, output.buffers[0]->parent().get()); + } else { + // New memory for offset not 0 mod 8 + ASSERT_NE(parent_buf, output.buffers[0].get()); + ASSERT_EQ(nullptr, output.buffers[0]->parent()); + } + } else { + // preallocated, so check that the validity bitmap is unbothered + ASSERT_EQ(preallocated_bitmap.get(), output.buffers[0].get()); + } + + ASSERT_EQ(arr->Slice(offset)->null_count(), output.GetNullCount()); + + ASSERT_TRUE(internal::BitmapEquals(output.buffers[0]->data(), output.offset, + sliced->null_bitmap_data(), sliced->offset(), + output.length)); + AssertValidityZeroExtraBits(output); + }; + + CheckSliced(8); + CheckSliced(7); + CheckSliced(8, /*preallocated=*/true); + CheckSliced(7, true); + CheckSliced(8, true, /*offset=*/4); + CheckSliced(7, true, 4); +} + +TEST_F(TestPropagateNulls, ZeroCopyWhenZeroNullsOnOneInput) { + const int64_t length = 16; + + constexpr uint8_t validity_bitmap[8] = {254, 0, 0, 0, 0, 0, 0, 0}; + auto nulls = std::make_shared(validity_bitmap, 8); + + ArrayData some_nulls(boolean(), 16, {nulls, nullptr}, /*null_count=*/9); + ArrayData no_nulls(boolean(), length, {nullptr, nullptr}, /*null_count=*/0); + + ArrayData output(boolean(), length, {nullptr, nullptr}); + ExecBatch batch({some_nulls, no_nulls}, length); + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + ASSERT_EQ(nulls.get(), output.buffers[0].get()); + ASSERT_EQ(9, output.null_count); + + // Flip order of args + output = ArrayData(boolean(), length, {nullptr, nullptr}); + 
batch.values = {no_nulls, no_nulls, some_nulls}; + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + ASSERT_EQ(nulls.get(), output.buffers[0].get()); + ASSERT_EQ(9, output.null_count); + + // Check that preallocated memory is not clobbered + uint8_t bitmap_data[2] = {0, 0}; + auto preallocated_mem = std::make_shared(bitmap_data, 2); + output.null_count = kUnknownNullCount; + output.buffers[0] = preallocated_mem; + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + + ASSERT_EQ(preallocated_mem.get(), output.buffers[0].get()); + ASSERT_EQ(9, output.null_count); + ASSERT_EQ(254, bitmap_data[0]); + ASSERT_EQ(0, bitmap_data[1]); +} + +TEST_F(TestPropagateNulls, IntersectsNulls) { + const int64_t length = 16; + + // 0b01111111 0b11001111 + constexpr uint8_t bitmap1[8] = {127, 207, 0, 0, 0, 0, 0, 0}; + + // 0b11111110 0b01111111 + constexpr uint8_t bitmap2[8] = {254, 127, 0, 0, 0, 0, 0, 0}; + + // 0b11101111 0b11111110 + constexpr uint8_t bitmap3[8] = {239, 254, 0, 0, 0, 0, 0, 0}; + + ArrayData arr1(boolean(), length, {std::make_shared(bitmap1, 8), nullptr}); + ArrayData arr2(boolean(), length, {std::make_shared(bitmap2, 8), nullptr}); + ArrayData arr3(boolean(), length, {std::make_shared(bitmap3, 8), nullptr}); + + auto CheckCase = [&](std::vector values, int64_t ex_null_count, + const uint8_t* ex_bitmap, bool preallocate = false, + int64_t output_offset = 0) { + ExecBatch batch(values, length); + + std::shared_ptr nulls; + if (preallocate) { + // Make the buffer one byte bigger so we can have non-zero offsets + ASSERT_OK_AND_ASSIGN(nulls, AllocateBuffer(3)); + std::memset(nulls->mutable_data(), 0, nulls->size()); + } else { + // non-zero output offset not permitted unless the output memory is + // preallocated + ASSERT_EQ(0, output_offset); + } + ArrayData output(boolean(), length, {nulls, nullptr}); + output.offset = output_offset; + + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + + // Preallocated memory used + if (preallocate) { + 
ASSERT_EQ(nulls.get(), output.buffers[0].get()); + } + + EXPECT_EQ(kUnknownNullCount, output.null_count); + EXPECT_EQ(ex_null_count, output.GetNullCount()); + + const auto& out_buffer = *output.buffers[0]; + + ASSERT_TRUE(internal::BitmapEquals(out_buffer.data(), output_offset, ex_bitmap, + /*ex_offset=*/0, length)); + + // Now check that the rest of the bits in out_buffer are still 0 + AssertValidityZeroExtraBits(output); + }; + + // 0b01101110 0b01001110 + uint8_t expected1[2] = {110, 78}; + CheckCase({arr1, arr2, arr3}, 7, expected1); + CheckCase({arr1, arr2, arr3}, 7, expected1, /*preallocate=*/true); + CheckCase({arr1, arr2, arr3}, 7, expected1, /*preallocate=*/true, + /*output_offset=*/4); + + // 0b01111110 0b01001111 + uint8_t expected2[2] = {126, 79}; + CheckCase({arr1, arr2}, 5, expected2); + CheckCase({arr1, arr2}, 5, expected2, /*preallocate=*/true, + /*output_offset=*/4); +} + +TEST_F(TestPropagateNulls, NullOutputTypeNoop) { + // Ensure we leave the buffers alone when the output type is null() + const int64_t length = 100; + ExecBatch batch({rng_->Boolean(100, 0.5, 0.5)}, length); + + ArrayData output(null(), length, {nullptr}); + ASSERT_OK(PropagateNulls(ctx_.get(), batch, &output)); + ASSERT_EQ(nullptr, output.buffers[0]); +} + +// ---------------------------------------------------------------------- +// ExecBatchIterator + +class TestExecBatchIterator : public TestComputeInternals { + public: + void SetupIterator(std::vector args, int64_t max_chunksize = -1) { + ASSERT_OK_AND_ASSIGN(iterator_, + ExecBatchIterator::Make(std::move(args), max_chunksize)); + } + void CheckIteration(const std::vector& args, int chunksize, + const std::vector& ex_batch_sizes) { + SetupIterator(args, chunksize); + ExecBatch batch; + int64_t position = 0; + for (size_t i = 0; i < ex_batch_sizes.size(); ++i) { + ASSERT_EQ(position, iterator_->position()); + ASSERT_TRUE(iterator_->Next(&batch)); + ASSERT_EQ(ex_batch_sizes[i], batch.length); + + for (size_t j = 0; j < 
args.size(); ++j) { + switch (args[j].kind()) { + case Datum::SCALAR: + ASSERT_TRUE(args[j].scalar()->Equals(batch[j].scalar())); + break; + case Datum::ARRAY: + AssertArraysEqual(*args[j].make_array()->Slice(position, batch.length), + *batch[j].make_array()); + break; + case Datum::CHUNKED_ARRAY: { + const ChunkedArray& carr = *args[j].chunked_array(); + if (batch.length == 0) { + ASSERT_EQ(0, carr.length()); + } else { + auto arg_slice = carr.Slice(position, batch.length); + // The sliced ChunkedArrays should only ever be 1 chunk + ASSERT_EQ(1, arg_slice->num_chunks()); + AssertArraysEqual(*arg_slice->chunk(0), *batch[j].make_array()); + } + } break; + default: + break; + } + } + position += ex_batch_sizes[i]; + } + // Ensure that the iterator is exhausted + ASSERT_FALSE(iterator_->Next(&batch)); + + ASSERT_EQ(iterator_->length(), iterator_->position()); + } + + protected: + std::unique_ptr iterator_; +}; + +TEST_F(TestExecBatchIterator, Basics) { + const int64_t length = 100; + + // Simple case with a single chunk + std::vector args = {Datum(GetInt32Array(length)), Datum(GetFloat64Array(length)), + Datum(std::make_shared(3))}; + SetupIterator(args); + + ExecBatch batch; + ASSERT_TRUE(iterator_->Next(&batch)); + ASSERT_EQ(3, batch.values.size()); + ASSERT_EQ(3, batch.num_values()); + ASSERT_EQ(length, batch.length); + + std::vector descrs = batch.GetDescriptors(); + ASSERT_EQ(ValueDescr::Array(int32()), descrs[0]); + ASSERT_EQ(ValueDescr::Array(float64()), descrs[1]); + ASSERT_EQ(ValueDescr::Scalar(int32()), descrs[2]); + + AssertArraysEqual(*args[0].make_array(), *batch[0].make_array()); + AssertArraysEqual(*args[1].make_array(), *batch[1].make_array()); + ASSERT_TRUE(args[2].scalar()->Equals(batch[2].scalar())); + + ASSERT_EQ(length, iterator_->position()); + ASSERT_FALSE(iterator_->Next(&batch)); + + // Split into chunks of size 16 + CheckIteration(args, /*chunksize=*/16, {16, 16, 16, 16, 16, 16, 4}); +} + +TEST_F(TestExecBatchIterator, InputValidation) { + 
std::vector args = {Datum(GetInt32Array(10)), Datum(GetInt32Array(9))}; + ASSERT_RAISES(Invalid, ExecBatchIterator::Make(args)); + + args = {Datum(GetInt32Array(9)), Datum(GetInt32Array(10))}; + ASSERT_RAISES(Invalid, ExecBatchIterator::Make(args)); + + args = {Datum(GetInt32Array(10))}; + ASSERT_OK_AND_ASSIGN(auto iterator, ExecBatchIterator::Make(args)); + ASSERT_EQ(10, iterator->max_chunksize()); +} + +TEST_F(TestExecBatchIterator, ChunkedArrays) { + std::vector args = {Datum(GetInt32Chunked({0, 20, 10})), + Datum(GetInt32Chunked({15, 15})), Datum(GetInt32Array(30)), + Datum(std::make_shared(5)), + Datum(MakeNullScalar(boolean()))}; + + CheckIteration(args, /*chunksize=*/10, {10, 5, 5, 10}); + CheckIteration(args, /*chunksize=*/20, {15, 5, 10}); + CheckIteration(args, /*chunksize=*/30, {15, 5, 10}); +} + +TEST_F(TestExecBatchIterator, ZeroLengthCases) { + auto carr = std::shared_ptr(new ChunkedArray({}, int32())); + + // Zero-length ChunkedArray with zero chunks + std::vector args = {Datum(carr)}; + CheckIteration(args, /*chunksize=*/10, {0}); + + // Zero-length array + args = {Datum(GetInt32Array(0))}; + CheckIteration(args, /*chunksize=*/10, {0}); + + // ChunkedArray with single empty chunk + args = {Datum(GetInt32Chunked({0}))}; + CheckIteration(args, /*chunksize=*/10, {0}); +} + +// ---------------------------------------------------------------------- +// Scalar function execution + +void ExecCopy(KernelContext*, const ExecBatch& batch, Datum* out) { + DCHECK_EQ(1, batch.num_values()); + const auto& type = checked_cast(*batch[0].type()); + int value_size = type.bit_width() / 8; + + const ArrayData& arg0 = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + uint8_t* dst = out_arr->buffers[1]->mutable_data() + out_arr->offset * value_size; + const uint8_t* src = arg0.buffers[1]->data() + arg0.offset * value_size; + std::memcpy(dst, src, batch.length * value_size); +} + +void ExecComputedBitmap(KernelContext* ctx, const ExecBatch& batch, Datum* 
out) { + // Propagate nulls not used. Check that the out bitmap isn't the same already + // as the input bitmap + const ArrayData& arg0 = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + + DCHECK(!internal::BitmapEquals(arg0.buffers[0]->data(), arg0.offset, + out_arr->buffers[0]->data(), out_arr->offset, + batch.length)); + internal::CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length, + out_arr->buffers[0]->mutable_data(), out_arr->offset); + ExecCopy(ctx, batch, out); +} + +void ExecNoPreallocatedData(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Validity preallocated, but not the data + ArrayData* out_arr = out->mutable_array(); + DCHECK_EQ(0, out_arr->offset); + const auto& type = checked_cast(*batch[0].type()); + int value_size = type.bit_width() / 8; + Status s = (ctx->Allocate(out_arr->length * value_size).Value(&out_arr->buffers[1])); + DCHECK_OK(s); + ExecCopy(ctx, batch, out); +} + +void ExecNoPreallocatedAnything(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Neither validity nor data preallocated + ArrayData* out_arr = out->mutable_array(); + DCHECK_EQ(0, out_arr->offset); + Status s = (ctx->AllocateBitmap(out_arr->length).Value(&out_arr->buffers[0])); + DCHECK_OK(s); + const ArrayData& arg0 = *batch[0].array(); + internal::CopyBitmap(arg0.buffers[0]->data(), arg0.offset, batch.length, + out_arr->buffers[0]->mutable_data(), /*offset=*/0); + + // Reuse the kernel that allocates the data + ExecNoPreallocatedData(ctx, batch, out); +} + +struct ExampleOptions : public FunctionOptions { + std::shared_ptr value; + explicit ExampleOptions(std::shared_ptr value) : value(std::move(value)) {} +}; + +struct ExampleState : public KernelState { + std::shared_ptr value; + explicit ExampleState(std::shared_ptr value) : value(std::move(value)) {} +}; + +std::unique_ptr InitStateful(KernelContext*, const Kernel&, + const FunctionOptions* options) { + auto func_options = static_cast(options); + return 
std::unique_ptr(new ExampleState{func_options->value}); +} + +void ExecStateful(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // We take the value from the state and multiply the data in batch[0] with it + ExampleState* state = static_cast(ctx->state()); + int32_t multiplier = checked_cast(*state->value).value; + + const ArrayData& arg0 = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + const int32_t* arg0_data = arg0.GetValues(1); + int32_t* dst = out_arr->GetMutableValues(1); + for (int64_t i = 0; i < arg0.length; ++i) { + dst[i] = arg0_data[i] * multiplier; + } +} + +void ExecAddInt32(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const Int32Scalar& arg0 = batch[0].scalar_as(); + const Int32Scalar& arg1 = batch[1].scalar_as(); + out->value = std::make_shared(arg0.value + arg1.value); +} + +class TestExecScalarFunction : public TestComputeInternals { + public: + void SetUp() { + TestComputeInternals::SetUp(); + + AddCopyFunctions(); + AddNoPreallocateFunctions(); + AddStatefulFunction(); + AddScalarFunction(); + } + + void AddCopyFunctions() { + // This function simply copies memory from the input argument into the + // (preallocated) output + auto func = std::make_shared("copy", 1); + + // Add a few kernels. 
Our implementation only accepts arrays + ASSERT_OK(func->AddKernel({InputType::Array(uint8())}, uint8(), ExecCopy)); + ASSERT_OK(func->AddKernel({InputType::Array(int32())}, int32(), ExecCopy)); + ASSERT_OK(func->AddKernel({InputType::Array(float64())}, float64(), ExecCopy)); + ASSERT_OK(registry_->AddFunction(func)); + + // A version which doesn't want the executor to call PropagateNulls + auto func2 = std::make_shared("copy_computed_bitmap", 1); + ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecComputedBitmap); + kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; + ASSERT_OK(func2->AddKernel(kernel)); + ASSERT_OK(registry_->AddFunction(func2)); + } + + void AddNoPreallocateFunctions() { + // A function that allocates its own output memory. We have cases for both + // non-preallocated data and non-preallocated validity bitmap + auto f1 = std::make_shared("nopre_data", 1); + auto f2 = std::make_shared("nopre_validity_or_data", 1); + + ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecNoPreallocatedData); + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + ASSERT_OK(f1->AddKernel(kernel)); + + kernel.exec = ExecNoPreallocatedAnything; + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + ASSERT_OK(f2->AddKernel(kernel)); + + ASSERT_OK(registry_->AddFunction(f1)); + ASSERT_OK(registry_->AddFunction(f2)); + } + + void AddStatefulFunction() { + // This function's behavior depends on a static parameter that is made + // available to the kernel's execution function through its Options object + auto func = std::make_shared("stateful", 1); + + ScalarKernel kernel({InputType::Array(int32())}, int32(), ExecStateful, InitStateful); + ASSERT_OK(func->AddKernel(kernel)); + ASSERT_OK(registry_->AddFunction(func)); + } + + void AddScalarFunction() { + auto func = std::make_shared("scalar_add_int32", 2); + ASSERT_OK(func->AddKernel({InputType::Scalar(int32()), InputType::Scalar(int32())}, + int32(), ExecAddInt32)); + 
ASSERT_OK(registry_->AddFunction(func)); + } +}; + +TEST_F(TestExecScalarFunction, ArgumentValidation) { + // Copy accepts only a single array argument + Datum d1(GetInt32Array(10)); + + // Too many args + std::vector args = {d1, d1}; + ASSERT_RAISES(Invalid, ExecScalarFunction(exec_ctx_.get(), "copy", args)); + + // Too few + args = {}; + ASSERT_RAISES(Invalid, ExecScalarFunction(exec_ctx_.get(), "copy", args)); + + // Cannot do scalar + args = {Datum(std::make_shared(5))}; + ASSERT_RAISES(KeyError, ExecScalarFunction(exec_ctx_.get(), "copy", args)); +} + +TEST_F(TestExecScalarFunction, PreallocationCases) { + double null_prob = 0.2; + + auto arr = GetUInt8Array(50, null_prob); + + auto CheckFunction = [&](std::string func_name) { + ResetContexts(); + + // The default should be a single array output + { + std::vector args = {Datum(arr)}; + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + ASSERT_EQ(Datum::ARRAY, result.kind()); + AssertArraysEqual(*arr, *result.make_array()); + } + + // Set the exec_chunksize to be smaller, so now we have several invocations + // of the kernel, but still the output is onee array + { + std::vector args = {Datum(arr)}; + exec_ctx_->set_exec_chunksize(8); + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + AssertArraysEqual(*arr, *result.make_array()); + } + + exec_ctx_->set_exec_chunksize(12); + + // Chunksize not multiple of 8 + { + std::vector args = {Datum(arr)}; + exec_ctx_->set_exec_chunksize(12); + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + AssertArraysEqual(*arr, *result.make_array()); + } + + // Input is chunked, output has one big chunk + { + auto carr = std::shared_ptr( + new ChunkedArray({arr->Slice(0, 15), arr->Slice(15)})); + std::vector args = {Datum(carr)}; + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + std::shared_ptr actual = 
result.chunked_array(); + ASSERT_EQ(1, actual->num_chunks()); + AssertChunkedEquivalent(*carr, *actual); + } + + // Preallocate independently for each batch + { + std::vector args = {Datum(arr)}; + exec_ctx_->set_preallocate_contiguous(false); + exec_ctx_->set_exec_chunksize(20); + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind()); + const ChunkedArray& carr = *result.chunked_array(); + ASSERT_EQ(3, carr.num_chunks()); + AssertArraysEqual(*arr->Slice(0, 20), *carr.chunk(0)); + AssertArraysEqual(*arr->Slice(20, 20), *carr.chunk(1)); + AssertArraysEqual(*arr->Slice(40), *carr.chunk(2)); + } + }; + + CheckFunction("copy"); + CheckFunction("copy_computed_bitmap"); +} + +TEST_F(TestExecScalarFunction, BasicNonStandardCases) { + // Test a handful of cases + // + // * Validity bitmap computed by kernel rather than using PropagateNulls + // * Data not pre-allocated + // * Validity bitmap not pre-allocated + + double null_prob = 0.2; + + auto arr = GetUInt8Array(100, null_prob); + std::vector args = {Datum(arr)}; + + auto CheckFunction = [&](std::string func_name) { + ResetContexts(); + + // The default should be a single array output + { + exec_ctx_->set_exec_chunksize(-1); + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + AssertArraysEqual(*arr, *result.make_array(), true); + } + + // Split execution into 3 chunks + { + exec_ctx_->set_exec_chunksize(40); + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), func_name, args)); + ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind()); + const ChunkedArray& carr = *result.chunked_array(); + ASSERT_EQ(3, carr.num_chunks()); + AssertArraysEqual(*arr->Slice(0, 40), *carr.chunk(0)); + AssertArraysEqual(*arr->Slice(40, 40), *carr.chunk(1)); + AssertArraysEqual(*arr->Slice(80), *carr.chunk(2)); + } + }; + + CheckFunction("nopre_data"); + CheckFunction("nopre_validity_or_data"); +} + 
+TEST_F(TestExecScalarFunction, StatefulKernel) { + auto input = ArrayFromJSON(int32(), "[1, 2, 3, null, 5]"); + auto multiplier = std::make_shared(2); + auto expected = ArrayFromJSON(int32(), "[2, 4, 6, null, 10]"); + + ExampleOptions options(multiplier); + std::vector args = {Datum(input)}; + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), "stateful", args, &options)); + AssertArraysEqual(*expected, *result.make_array()); +} + +TEST_F(TestExecScalarFunction, ScalarFunction) { + std::vector args = {Datum(std::make_shared(5)), + Datum(std::make_shared(7))}; + ASSERT_OK_AND_ASSIGN(Datum result, + ExecScalarFunction(exec_ctx_.get(), "scalar_add_int32", args)); + ASSERT_EQ(Datum::SCALAR, result.kind()); + + auto expected = std::make_shared(12); + ASSERT_TRUE(expected->Equals(*result.scalar())); +} + +} // namespace detail +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/filter.h b/cpp/src/arrow/compute/filter.h new file mode 100644 index 00000000000..260e9909b00 --- /dev/null +++ b/cpp/src/arrow/compute/filter.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/options.h" +#include "arrow/datum.h" +#include "arrow/result.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +/// \brief Filter with a boolean selection filter +/// +/// The output will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will be handled +/// based on options.null_selection_behavior. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// (null_selection_behavior == DROP) = ["b", "c", "f"] +/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"] +/// +/// \param[in] values array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[in] options configures null_selection_behavior +/// \param[in] context the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result Filter(const Datum& values, const Datum& filter, + FilterOptions options = FilterOptions::Defaults(), + ExecContext* context = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc new file mode 100644 index 00000000000..1c29ab7ed3b --- /dev/null +++ b/cpp/src/arrow/compute/function.cc @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/function.h" + +#include +#include +#include + +namespace arrow { + +struct ValueDescr; + +namespace compute { + +static Status CheckArity(const std::vector& args, const FunctionArity& arity) { + const int passed_num_args = static_cast(args.size()); + if (arity.is_varargs && passed_num_args < arity.num_args) { + return Status::Invalid("Varargs function needs at least ", arity.num_args, + " arguments but kernel accepts only ", passed_num_args); + } else if (!arity.is_varargs && passed_num_args != arity.num_args) { + return Status::Invalid("Function accepts ", arity.num_args, + " arguments but kernel accepts ", passed_num_args); + } + return Status::OK(); +} + +template +std::string FormatArgTypes(const std::vector& descrs) { + std::stringstream ss; + ss << "("; + for (size_t i = 0; i < descrs.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << descrs[i].ToString(); + } + ss << ")"; + return ss.str(); +} + +template +Result DispatchExactImpl(const Function& func, + const std::vector& kernels, + const std::vector& values) { + const int passed_num_args = static_cast(values.size()); + + // Validate arity + const FunctionArity arity = func.arity(); + if (arity.is_varargs && passed_num_args < arity.num_args) { + return Status::Invalid("Varargs function needs at least ", arity.num_args, + " arguments but passed only ", passed_num_args); + } else if (!arity.is_varargs && passed_num_args != arity.num_args) { + return Status::Invalid("Function accepts ", arity.num_args, " arguments but passed ", + passed_num_args); + } + for 
(const auto& kernel : kernels) { + if (kernel.signature->MatchesInputs(values)) { + return &kernel; + } + } + return Status::KeyError("Function ", func.name(), + " has no kernel exactly matching input types ", + FormatArgTypes(values)); +} + +Status ScalarFunction::AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init) { + RETURN_NOT_OK(CheckArity(in_types, arity_)); + + if (arity_.is_varargs && in_types.size() != 1) { + return Status::Invalid("Varargs signatures must have exactly one input type"); + } + auto sig = + KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); + kernels_.emplace_back(std::move(sig), exec, init); + return Status::OK(); +} + +Status ScalarFunction::AddKernel(ScalarKernel kernel) { + RETURN_NOT_OK(CheckArity(kernel.signature->in_types(), arity_)); + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Result ScalarFunction::DispatchExact( + const std::vector& values) const { + return DispatchExactImpl(*this, kernels_, values); +} + +Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init) { + RETURN_NOT_OK(CheckArity(in_types, arity_)); + + if (arity_.is_varargs && in_types.size() != 1) { + return Status::Invalid("Varargs signatures must have exactly one input type"); + } + auto sig = + KernelSignature::Make(std::move(in_types), std::move(out_type), arity_.is_varargs); + kernels_.emplace_back(std::move(sig), exec, init); + return Status::OK(); +} + +Status VectorFunction::AddKernel(VectorKernel kernel) { + RETURN_NOT_OK(CheckArity(kernel.signature->in_types(), arity_)); + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + 
kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Result VectorFunction::DispatchExact( + const std::vector& values) const { + return DispatchExactImpl(*this, kernels_, values); +} + +Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { + RETURN_NOT_OK(CheckArity(kernel.signature->in_types(), arity_)); + if (arity_.is_varargs && !kernel.signature->is_varargs()) { + return Status::Invalid("Function accepts varargs but kernel signature does not"); + } + kernels_.emplace_back(std::move(kernel)); + return Status::OK(); +} + +Result ScalarAggregateFunction::DispatchExact( + const std::vector& values) const { + return DispatchExactImpl(*this, kernels_, values); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h new file mode 100644 index 00000000000..3fa9ab1ae24 --- /dev/null +++ b/cpp/src/arrow/compute/function.h @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include + +#include "arrow/compute/kernel.h" +#include "arrow/compute/options.h" // IWYU pragma: keep +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +struct ValueDescr; + +namespace compute { + +/// \brief Contains the number of required arguments for the function +struct ARROW_EXPORT FunctionArity { + static FunctionArity Nullary() { return FunctionArity(0, false); } + static FunctionArity Unary() { return FunctionArity(1, false); } + static FunctionArity Binary() { return FunctionArity(2, false); } + static FunctionArity Ternary() { return FunctionArity(3, false); } + static FunctionArity Varargs(int min_args = 1) { return FunctionArity(min_args, true); } + + FunctionArity(int num_args, bool is_varargs = false) // NOLINT implicit conversion + : num_args(num_args), is_varargs(is_varargs) {} + + /// The number of required arguments (or the minimum number for varargs + /// functions) + int num_args; + + /// If true, then the num_args is the minimum number of required arguments + bool is_varargs = false; +}; + +/// \brief Base class for function containers that are capable of dispatch to +/// kernel implementations +class ARROW_EXPORT Function { + public: + /// \brief The kind of function, which indicates in what contexts it is + /// valid for use + enum Kind { + /// A function that performs scalar data operations on whole arrays of + /// data. Can generally process Array or Scalar values. The size of the + /// output will be the same as the size (or broadcasted size, in the case + /// of mixing Array and Scalar inputs) of the input. + SCALAR, + + /// A function with array input and output whose behavior depends on the + /// values of the entire arrays passed, rather than the value of each scalar + /// value. 
+ VECTOR, + + /// A function that computes scalar summary statistics from array input. + SCALAR_AGGREGATE + }; + + virtual ~Function() = default; + + /// \brief The name of the kernel. The registry enforces uniqueness of names + const std::string& name() const { return name_; } + + /// \brief The kind of kernel, which indicates in what contexts it is valid + /// for use + Function::Kind kind() const { return kind_; } + + /// \brief Contains the number of arguments the function requires + const FunctionArity& arity() const { return arity_; } + + /// \brief Returns the number of registered kernels for this function + virtual int num_kernels() const = 0; + + protected: + Function(std::string name, Function::Kind kind, const FunctionArity& arity) + : name_(std::move(name)), kind_(kind), arity_(arity) {} + std::string name_; + Function::Kind kind_; + FunctionArity arity_; +}; + +namespace detail { + +template +class FunctionImpl : public Function { + public: + /// \brief Return vector of all available kernels for this function + const std::vector& kernels() const { return kernels_; } + + int num_kernels() const override { return static_cast(kernels_.size()); } + + protected: + FunctionImpl(std::string name, Function::Kind kind, const FunctionArity& arity) + : Function(std::move(name), kind, arity) {} + + std::vector kernels_; +}; + +} // namespace detail + +/// \brief A function that executes elementwise operations on arrays or +/// scalars, and therefore whose results generally do not depend on the order +/// of the values in the arguments. Accepts and returns arrays that are all of +/// the same size. These functions roughly correspond to the functions used in +/// SQL expressions. 
+class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl { + public: + using KernelType = ScalarKernel; + + ScalarFunction(std::string name, const FunctionArity& arity) + : detail::FunctionImpl(std::move(name), Function::SCALAR, arity) {} + + /// \brief Add a simple kernel (function implementation) with given + /// input/output types, no required state initialization, preallocation for + /// fixed-width types, and default null handling (intersect validity bitmaps + /// of inputs) + Status AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec func, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if fails + /// to match the other parameters of the function + Status AddKernel(ScalarKernel kernel); + + /// \brief Return the first kernel that can execute the function given the + /// exact argument types (without implicit type casts or scalar->array + /// promotions) + Result DispatchExact(const std::vector& values) const; +}; + +/// \brief A function that executes general array operations that may yield +/// outputs of different sizes or have results that depend on the whole array +/// contents. These functions roughly correspond to the functions found in +/// non-SQL array languages like APL and its derivatives +class ARROW_EXPORT VectorFunction : public detail::FunctionImpl { + public: + using KernelType = VectorKernel; + + VectorFunction(std::string name, const FunctionArity& arity) + : detail::FunctionImpl(std::move(name), Function::VECTOR, arity) {} + + /// \brief Add a simple kernel (function implementation) with given + /// input/output types, no required state initialization, preallocation for + /// fixed-width types, and default null handling (intersect validity bitmaps + /// of inputs) + Status AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec func, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). 
Returns error if fails + /// to match the other parameters of the function + Status AddKernel(VectorKernel kernel); + + /// \brief Return the first kernel that can execute the function given the + /// exact argument types (without implicit type casts or scalar->array + /// promotions) + Result DispatchExact(const std::vector& values) const; +}; + +class ARROW_EXPORT ScalarAggregateFunction + : public detail::FunctionImpl { + public: + using KernelType = ScalarAggregateKernel; + + ScalarAggregateFunction(std::string name, const FunctionArity& arity) + : detail::FunctionImpl(std::move(name), + Function::SCALAR_AGGREGATE, arity) {} + + /// \brief Add a kernel (function implementation). Returns error if fails + /// to match the other parameters of the function + Status AddKernel(ScalarAggregateKernel kernel); + + Result DispatchExact( + const std::vector& values) const; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc new file mode 100644 index 00000000000..89c3ed00352 --- /dev/null +++ b/cpp/src/arrow/compute/function_test.cc @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include + +#include + +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" + +namespace arrow { + +struct Datum; + +namespace compute { + +class ExecContext; +struct FunctionOptions; + +TEST(FunctionArity, Basics) { + auto nullary = FunctionArity::Nullary(); + ASSERT_EQ(0, nullary.num_args); + ASSERT_FALSE(nullary.is_varargs); + + auto unary = FunctionArity::Unary(); + ASSERT_EQ(1, unary.num_args); + + auto binary = FunctionArity::Binary(); + ASSERT_EQ(2, binary.num_args); + + auto ternary = FunctionArity::Ternary(); + ASSERT_EQ(3, ternary.num_args); + + auto varargs = FunctionArity::Varargs(); + ASSERT_EQ(1, varargs.num_args); + ASSERT_TRUE(varargs.is_varargs); + + auto varargs2 = FunctionArity::Varargs(2); + ASSERT_EQ(2, varargs2.num_args); + ASSERT_TRUE(varargs2.is_varargs); +} + +TEST(ScalarFunction, Basics) { + ScalarFunction func("scalar_test", 2); + ScalarFunction varargs_func("varargs_test", FunctionArity::Varargs(1)); + + ASSERT_EQ("scalar_test", func.name()); + ASSERT_EQ(2, func.arity().num_args); + ASSERT_FALSE(func.arity().is_varargs); + ASSERT_EQ(Function::SCALAR, func.kind()); + + ASSERT_EQ("varargs_test", varargs_func.name()); + ASSERT_EQ(1, varargs_func.arity().num_args); + ASSERT_TRUE(varargs_func.arity().is_varargs); + ASSERT_EQ(Function::SCALAR, varargs_func.kind()); +} + +TEST(VectorFunction, Basics) { + VectorFunction func("vector_test", 2); + VectorFunction varargs_func("varargs_test", FunctionArity::Varargs(1)); + + ASSERT_EQ("vector_test", func.name()); + ASSERT_EQ(2, func.arity().num_args); + ASSERT_FALSE(func.arity().is_varargs); + ASSERT_EQ(Function::VECTOR, func.kind()); + + ASSERT_EQ("varargs_test", varargs_func.name()); + ASSERT_EQ(1, varargs_func.arity().num_args); + ASSERT_TRUE(varargs_func.arity().is_varargs); + ASSERT_EQ(Function::VECTOR, varargs_func.kind()); +} + +auto ExecNYI = [](KernelContext* 
ctx, const ExecBatch& args, Datum* out) { + ctx->SetStatus(Status::NotImplemented("NYI")); + return; +}; + +template +void CheckAddDispatch(FunctionType* func) { + using KernelType = typename FunctionType::KernelType; + + ASSERT_EQ(0, func->num_kernels()); + ASSERT_EQ(0, func->kernels().size()); + + std::vector in_types1 = {int32(), int32()}; + OutputType out_type1 = int32(); + + ASSERT_OK(func->AddKernel(in_types1, out_type1, ExecNYI)); + ASSERT_OK(func->AddKernel({int32(), int8()}, int32(), ExecNYI)); + + // Duplicate sig is okay + ASSERT_OK(func->AddKernel(in_types1, out_type1, ExecNYI)); + + // Add given a descr + KernelType descr({float64(), float64()}, float64(), ExecNYI); + ASSERT_OK(func->AddKernel(descr)); + + ASSERT_EQ(4, func->num_kernels()); + ASSERT_EQ(4, func->kernels().size()); + + // Try adding some invalid kernels + ASSERT_RAISES(Invalid, func->AddKernel({}, int32(), ExecNYI)); + ASSERT_RAISES(Invalid, func->AddKernel({int32()}, int32(), ExecNYI)); + ASSERT_RAISES(Invalid, func->AddKernel({int8(), int8(), int8()}, int32(), ExecNYI)); + + // Add valid and invalid kernel using kernel struct directly + KernelType valid_kernel({boolean(), boolean()}, boolean(), ExecNYI); + ASSERT_OK(func->AddKernel(valid_kernel)); + + KernelType invalid_kernel({boolean()}, boolean(), ExecNYI); + ASSERT_RAISES(Invalid, func->AddKernel(invalid_kernel)); + + ASSERT_OK_AND_ASSIGN(const KernelType* kernel, func->DispatchExact({int32(), int32()})); + KernelSignature expected_sig(in_types1, out_type1); + ASSERT_TRUE(kernel->signature->Equals(expected_sig)); + + // No kernel available + ASSERT_RAISES(KeyError, func->DispatchExact({utf8(), utf8()})); + + // Wrong arity + ASSERT_RAISES(Invalid, func->DispatchExact({})); + ASSERT_RAISES(Invalid, func->DispatchExact({int32(), int32(), int32()})); +} + +TEST(ScalarVectorFunction, DispatchExact) { + ScalarFunction func1("scalar_test", 2); + VectorFunction func2("vector_test", 2); + + CheckAddDispatch(&func1); + 
CheckAddDispatch(&func2); +} + +TEST(ArrayFunction, Varargs) { + ScalarFunction va_func("va_test", FunctionArity::Varargs(1)); + + std::vector va_args = {int8()}; + + ASSERT_OK(va_func.AddKernel(va_args, int8(), ExecNYI)); + + // No input type passed + ASSERT_RAISES(Invalid, va_func.AddKernel({}, int8(), ExecNYI)); + + // Varargs function expect a single input type + ASSERT_RAISES(Invalid, va_func.AddKernel({int8(), int8()}, int8(), ExecNYI)); + + // Invalid sig + ScalarKernel non_va_kernel(std::make_shared(va_args, int8()), ExecNYI); + ASSERT_RAISES(Invalid, va_func.AddKernel(non_va_kernel)); + + std::vector args = {ValueDescr::Scalar(int8()), int8(), int8()}; + ASSERT_OK_AND_ASSIGN(const ScalarKernel* kernel, va_func.DispatchExact(args)); + ASSERT_TRUE(kernel->signature->MatchesInputs(args)); + + // No dispatch possible because args incompatible + args[2] = int32(); + ASSERT_RAISES(KeyError, va_func.DispatchExact(args)); +} + +TEST(ScalarAggregateFunction, Basics) { + ScalarAggregateFunction func("agg_test", 1); + + ASSERT_EQ("agg_test", func.name()); + ASSERT_EQ(1, func.arity().num_args); + ASSERT_FALSE(func.arity().is_varargs); + ASSERT_EQ(Function::SCALAR_AGGREGATE, func.kind()); +} + +std::unique_ptr NoopInit(KernelContext*, const Kernel&, + const FunctionOptions*) { + return nullptr; +} + +void NoopConsume(KernelContext*, const ExecBatch&) {} +void NoopMerge(KernelContext*, const KernelState&, KernelState*) {} +void NoopFinalize(KernelContext*, Datum*) {} + +TEST(ScalarAggregateFunction, DispatchExact) { + ScalarAggregateFunction func("agg_test", 1); + + std::vector in_args = {ValueDescr::Array(int8())}; + ScalarAggregateKernel kernel(std::move(in_args), int64(), NoopInit, NoopConsume, + NoopMerge, NoopFinalize); + ASSERT_OK(func.AddKernel(kernel)); + + in_args = {float64()}; + kernel.signature = std::make_shared(in_args, float64()); + ASSERT_OK(func.AddKernel(kernel)); + + ASSERT_EQ(2, func.num_kernels()); + ASSERT_EQ(2, func.kernels().size()); + 
ASSERT_TRUE(func.kernels()[1].signature->Equals(*kernel.signature)); + + // Invalid arity + in_args = {}; + kernel.signature = std::make_shared(in_args, float64()); + ASSERT_RAISES(Invalid, func.AddKernel(kernel)); + + in_args = {float32(), float64()}; + kernel.signature = std::make_shared(in_args, float64()); + ASSERT_RAISES(Invalid, func.AddKernel(kernel)); + + std::vector dispatch_args = {ValueDescr::Array(int8())}; + ASSERT_OK_AND_ASSIGN(const ScalarAggregateKernel* selected_kernel, + func.DispatchExact(dispatch_args)); + ASSERT_EQ(&func.kernels()[0], selected_kernel); + ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); + + // We declared that only arrays are accepted + dispatch_args[0] = {ValueDescr::Scalar(int8())}; + ASSERT_RAISES(KeyError, func.DispatchExact(dispatch_args)); + + // Didn't qualify the float64() kernel so this actually dispatches (even + // though that may not be what you want) + dispatch_args[0] = {ValueDescr::Scalar(float64())}; + ASSERT_OK_AND_ASSIGN(selected_kernel, func.DispatchExact(dispatch_args)); + ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc new file mode 100644 index 00000000000..b03523c8be0 --- /dev/null +++ b/cpp/src/arrow/compute/kernel.cc @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernel.h" + +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/result.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/hashing.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" + +namespace arrow { + +using internal::hash_combine; + +static constexpr size_t kHashSeed = 0; + +namespace compute { + +// ---------------------------------------------------------------------- +// KernelContext + +inline void ZeroLastByte(Buffer* buffer) { + *(buffer->mutable_data() + (buffer->size() - 1)) = 0; +} + +Result> KernelContext::Allocate(int64_t nbytes) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, + AllocateBuffer(nbytes, exec_ctx_->memory_pool())); + result->ZeroPadding(); + return result; +} + +Result> KernelContext::AllocateBitmap(int64_t num_bits) { + const int64_t nbytes = BitUtil::BytesForBits(num_bits); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, + AllocateBuffer(nbytes, exec_ctx_->memory_pool())); + // Some utility methods access the last byte before it might be + // initialized this makes valgrind/asan unhappy, so we proactively + // zero it. 
+ ZeroLastByte(result.get()); + result->ZeroPadding(); + return result; +} + +void KernelContext::SetStatus(const Status& status) { + if (ARROW_PREDICT_FALSE(!status_.ok())) { + return; + } + status_ = status; +} + +/// \brief Clear any error status +void KernelContext::ResetStatus() { status_ = Status::OK(); } + +// ---------------------------------------------------------------------- +// InputType + +size_t InputType::Hash() const { + size_t result = kHashSeed; + hash_combine(result, static_cast(shape_)); + switch (kind_) { + case InputType::EXACT_TYPE: + hash_combine(result, type_->Hash()); + break; + case InputType::SAME_TYPE_ID: + hash_combine(result, static_cast(type_id_)); + break; + default: + break; + } + return result; +} + +std::string InputType::ToString() const { + std::stringstream ss; + switch (shape_) { + case ValueDescr::ANY: + ss << "any"; + break; + case ValueDescr::ARRAY: + ss << "array"; + break; + case ValueDescr::SCALAR: + ss << "scalar"; + break; + default: + DCHECK(false); + break; + } + ss << "["; + switch (kind_) { + case InputType::EXACT_TYPE: + ss << type_->ToString(); + break; + case InputType::SAME_TYPE_ID: { + // Indicate that the parameters for the type are unspecified. 
TODO: don't + // show this for types without parameters, like Type::INT32 + ss << internal::ToString(type_id_) << "*"; + } break; + default: + DCHECK(false); + break; + } + ss << "]"; + return ss.str(); +} + +bool InputType::Equals(const InputType& other) const { + if (this == &other) { + return true; + } + if (kind_ != other.kind_ || shape_ != other.shape_) { + return false; + } + switch (kind_) { + case InputType::EXACT_TYPE: + return type_->Equals(*other.type_); + case InputType::SAME_TYPE_ID: + return type_id_ == other.type_id_; + default: + return false; + } +} + +bool InputType::Matches(const ValueDescr& descr) const { + if (shape_ != ValueDescr::ANY && descr.shape != shape_) { + return false; + } + switch (kind_) { + case InputType::EXACT_TYPE: + return type_->Equals(*descr.type); + case InputType::SAME_TYPE_ID: + return type_id_ == descr.type->id(); + default: + // ANY_TYPE + return true; + } +} + +bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); } + +const std::shared_ptr& InputType::type() const { + DCHECK_EQ(InputType::EXACT_TYPE, kind_); + return type_; +} + +Type::type InputType::type_id() const { + DCHECK_EQ(InputType::SAME_TYPE_ID, kind_); + return type_id_; +} + +// ---------------------------------------------------------------------- +// OutputType + +OutputType::Resolver ResolveAs(ValueDescr descr) { + return [descr](const std::vector&) { return descr; }; +} + +OutputType::OutputType(ValueDescr descr) : resolver_(ResolveAs(descr)) {} + +Result OutputType::Resolve(const std::vector& args) const { + if (kind_ == OutputType::FIXED) { + return ValueDescr(type_, GetBroadcastShape(args)); + } else { + return resolver_(args); + } +} + +const std::shared_ptr& OutputType::type() const { + DCHECK_EQ(FIXED, kind_); + return type_; +} + +const OutputType::Resolver& OutputType::resolver() const { + DCHECK_EQ(COMPUTED, kind_); + return resolver_; +} + +std::string OutputType::ToString() const { + if (kind_ == 
OutputType::FIXED) { + return type_->ToString(); + } else { + return "computed"; + } +} + +// ---------------------------------------------------------------------- +// KernelSignature + +KernelSignature::KernelSignature(std::vector in_types, OutputType out_type, + bool is_varargs) + : in_types_(std::move(in_types)), + out_type_(std::move(out_type)), + is_varargs_(is_varargs), + hash_code_(0) { + // Varargs sigs must have only a single input type to use for argument validation + DCHECK(!is_varargs || (is_varargs && (in_types_.size() == 1))); +} + +std::shared_ptr KernelSignature::Make(std::vector in_types, + OutputType out_type, + bool is_varargs) { + return std::make_shared(std::move(in_types), std::move(out_type), + is_varargs); +} + +bool KernelSignature::Equals(const KernelSignature& other) const { + if (is_varargs_ != other.is_varargs_) { + return false; + } + if (in_types_.size() != other.in_types_.size()) { + return false; + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (!in_types_[i].Equals(other.in_types_[i])) { + return false; + } + } + return true; +} + +bool KernelSignature::MatchesInputs(const std::vector& args) const { + if (is_varargs_) { + for (const auto& arg : args) { + if (!in_types_[0].Matches(arg)) { + return false; + } + } + } else { + if (args.size() != in_types_.size()) { + return false; + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (!in_types_[i].Matches(args[i])) { + return false; + } + } + } + return true; +} + +int64_t KernelSignature::Hash() const { + if (hash_code_ != 0) { + return hash_code_; + } + size_t result = kHashSeed; + for (const auto& in_type : in_types_) { + hash_combine(result, in_type.Hash()); + } + hash_code_ = result; + return result; +} + +std::string KernelSignature::ToString() const { + std::stringstream ss; + + if (is_varargs_) { + ss << "varargs[" << in_types_[0].ToString() << "]"; + } else { + ss << "("; + for (size_t i = 0; i < in_types_.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss 
<< in_types_[i].ToString(); + } + ss << ")"; + } + ss << " -> " << out_type_.ToString(); + return ss.str(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 16dca696567..30eb097f5ef 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -15,295 +15,472 @@ // specific language governing permissions and limitations // under the License. +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + #pragma once +#include +#include #include +#include #include #include -#include "arrow/array.h" -#include "arrow/record_batch.h" -#include "arrow/scalar.h" -#include "arrow/table.h" -#include "arrow/util/macros.h" -#include "arrow/util/memory.h" -#include "arrow/util/variant.h" // IWYU pragma: export +#include "arrow/compute/exec.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" #include "arrow/util/visibility.h" namespace arrow { + +class Buffer; +struct Datum; + namespace compute { -class FunctionContext; +struct FunctionOptions; -/// \class OpKernel -/// \brief Base class for operator kernels -/// -/// Note to implementors: -/// Operator kernels are intended to be the lowest level of an analytics/compute -/// engine. They will generally not be exposed directly to end-users. Instead -/// they will be wrapped by higher level constructs (e.g. top-level functions -/// or physical execution plan nodes). These higher level constructs are -/// responsible for user input validation and returning the appropriate -/// error Status. -/// -/// Due to this design, implementations of Call (the execution -/// method on subclasses) should use assertions (i.e. DCHECK) to double-check -/// parameter arguments when in higher level components returning an -/// InvalidArgument error might be more appropriate. 
-/// -class ARROW_EXPORT OpKernel { +/// \brief Base class for opaque kernel-specific state. For example, if there +/// is some kind of initialization required +struct KernelState { + virtual ~KernelState() = default; +}; + +/// \brief Context/state for the execution of a particular kernel +class ARROW_EXPORT KernelContext { public: - virtual ~OpKernel() = default; - /// \brief EXPERIMENTAL The output data type of the kernel - /// \return the output type - virtual std::shared_ptr out_type() const = 0; + explicit KernelContext(ExecContext* exec_ctx) : exec_ctx_(exec_ctx) {} + + /// \brief Allocate buffer from the context's memory pool + Result> Allocate(int64_t nbytes); + + /// \brief Allocate buffer for bitmap from the context's memory pool + Result> AllocateBitmap(int64_t num_bits); + + /// \brief Indicate that an error has occurred, to be checked by a exec caller + /// \param[in] status a Status instance + /// + /// \note Will not overwrite a prior set Status, so we will have the first + /// error that occurred until ExecContext::ResetStatus is called + void SetStatus(const Status& status); + + /// \brief Clear any error status + void ResetStatus(); + + /// \brief Return true if an error has occurred + bool HasError() const { return !status_.ok(); } + + /// \brief Return the current status of the context + const Status& status() const { return status_; } + + // For passing kernel state to + void SetState(KernelState* state) { state_ = state; } + + KernelState* state() { return state_; } + + /// \brief Common state related to function execution + ExecContext* exec_context() { return exec_ctx_; } + + private: + ExecContext* exec_ctx_; + Status status_; + KernelState* state_; }; -struct Datum; -static inline bool CollectionEquals(const std::vector& left, - const std::vector& right); - -// Datums variants may have a length. This special value indicate that the -// current variant does not have a length. 
-constexpr int64_t kUnknownLength = -1; - -/// \class Datum -/// \brief Variant type for various Arrow C++ data structures -struct ARROW_EXPORT Datum { - enum type { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE, COLLECTION }; - - util::variant, std::shared_ptr, - std::shared_ptr, std::shared_ptr, - std::shared_ptr
, std::vector> - value; - - /// \brief Empty datum, to be populated elsewhere - Datum() : value(NULLPTR) {} - - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : Datum(value ? value->data() : NULLPTR) {} - - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::shared_ptr
& value) // NOLINT implicit conversion - : value(value) {} - Datum(const std::vector& value) // NOLINT implicit conversion - : value(value) {} - - // Cast from subtypes of Array to Datum - template ::value>> - Datum(const std::shared_ptr& value) // NOLINT implicit conversion - : Datum(std::shared_ptr(value)) {} - - // Convenience constructors - explicit Datum(bool value) : value(std::make_shared(value)) {} - explicit Datum(int8_t value) : value(std::make_shared(value)) {} - explicit Datum(uint8_t value) : value(std::make_shared(value)) {} - explicit Datum(int16_t value) : value(std::make_shared(value)) {} - explicit Datum(uint16_t value) : value(std::make_shared(value)) {} - explicit Datum(int32_t value) : value(std::make_shared(value)) {} - explicit Datum(uint32_t value) : value(std::make_shared(value)) {} - explicit Datum(int64_t value) : value(std::make_shared(value)) {} - explicit Datum(uint64_t value) : value(std::make_shared(value)) {} - explicit Datum(float value) : value(std::make_shared(value)) {} - explicit Datum(double value) : value(std::make_shared(value)) {} - - ~Datum() {} - - Datum(const Datum& other) noexcept { this->value = other.value; } - - Datum& operator=(const Datum& other) noexcept { - value = other.value; - return *this; - } +/// A standard function taking zero or more Array/Scalar values and returning +/// Array/Scalar output. May be used for SCALAR and VECTOR kernel kinds. Should +/// write into pre-allocated memory except in cases when a builder +/// (e.g. 
StringBuilder) must be employed +using ArrayKernelExec = std::function; - // Define move constructor and move assignment, for better performance - Datum(Datum&& other) noexcept : value(std::move(other.value)) {} +/// \brief A container to express what kernel argument input types are accepted +class ARROW_EXPORT InputType { + public: + enum Kind { + /// Accept any value type + ANY_TYPE, - Datum& operator=(Datum&& other) noexcept { - value = std::move(other.value); - return *this; - } + /// A fixed arrow::DataType and will only exact match having this exact + /// type (e.g. same TimestampType unit, same decimal scale and precision, + /// or same nested child types + EXACT_TYPE, - Datum::type kind() const { - switch (this->value.index()) { - case 0: - return Datum::NONE; - case 1: - return Datum::SCALAR; - case 2: - return Datum::ARRAY; - case 3: - return Datum::CHUNKED_ARRAY; - case 4: - return Datum::RECORD_BATCH; - case 5: - return Datum::TABLE; - case 6: - return Datum::COLLECTION; - default: - return Datum::NONE; - } - } + /// Any type having the indicated Type::type id. 
For example, accept + /// any Type::LIST or any Type::TIMESTAMP + SAME_TYPE_ID, + }; - std::shared_ptr array() const { - return util::get>(this->value); - } + InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction + : kind_(ANY_TYPE), shape_(shape) {} - std::shared_ptr make_array() const { - return MakeArray(util::get>(this->value)); - } + InputType(std::shared_ptr type, + ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction + : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {} - std::shared_ptr chunked_array() const { - return util::get>(this->value); - } + InputType(const ValueDescr& descr) // NOLINT implicit construction + : InputType(descr.type, descr.shape) {} - std::shared_ptr record_batch() const { - return util::get>(this->value); - } + InputType(Type::type type_id, + ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction + : kind_(SAME_TYPE_ID), shape_(shape), type_id_(type_id) {} - std::shared_ptr
table() const { - return util::get>(this->value); - } + InputType(const InputType& other) { CopyInto(other); } - const std::vector collection() const { - return util::get>(this->value); + // Convenience ctors + static InputType Array(std::shared_ptr type) { + return InputType(std::move(type), ValueDescr::ARRAY); } - std::shared_ptr scalar() const { - return util::get>(this->value); + static InputType Scalar(std::shared_ptr type) { + return InputType(std::move(type), ValueDescr::SCALAR); } - bool is_array() const { return this->kind() == Datum::ARRAY; } + static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); } - bool is_arraylike() const { - return this->kind() == Datum::ARRAY || this->kind() == Datum::CHUNKED_ARRAY; - } + static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); } - bool is_scalar() const { return this->kind() == Datum::SCALAR; } + void operator=(const InputType& other) { CopyInto(other); } - bool is_collection() const { return this->kind() == Datum::COLLECTION; } + InputType(InputType&& other) { MoveInto(std::forward(other)); } - /// \brief The value type of the variant, if any - /// - /// \return nullptr if no type - std::shared_ptr type() const { - if (this->kind() == Datum::ARRAY) { - return util::get>(this->value)->type; - } else if (this->kind() == Datum::CHUNKED_ARRAY) { - return util::get>(this->value)->type(); - } else if (this->kind() == Datum::SCALAR) { - return util::get>(this->value)->type; - } - return NULLPTR; + void operator=(InputType&& other) { MoveInto(std::forward(other)); } + + /// \brief Return true if this type exactly matches another + bool Equals(const InputType& other) const; + + bool operator==(const InputType& other) const { return this->Equals(other); } + + bool operator!=(const InputType& other) const { return !(*this == other); } + + /// \brief Return hash code + uint64_t Hash() const; + + /// \brief Render a human-readable string representation + std::string 
ToString() const; + + /// \brief Return true if the value matches this argument kind in type + /// and shape + bool Matches(const Datum& value) const; + + /// \brief Return true if the value descriptor matches this argument kind in + /// type and shape + bool Matches(const ValueDescr& value) const; + + /// \brief The type matching rule that this InputType uses + Kind kind() const { return kind_; } + + ValueDescr::Shape shape() const { return shape_; } + + /// \brief For ArgKind::EXACT_TYPE, the exact type that this InputType must + /// match. Otherwise this function should not be used + const std::shared_ptr& type() const; + + /// \brief For ArgKind::SAME_TYPE_ID, the Type::type that this InputType must + /// match, Otherwise this function should not be used + Type::type type_id() const; + + private: + void CopyInto(const InputType& other) { + this->kind_ = other.kind_; + this->shape_ = other.shape_; + this->type_ = other.type_; + this->type_id_ = other.type_id_; } - /// \brief The value length of the variant, if any - /// - /// \return kUnknownLength if no type - int64_t length() const { - if (this->kind() == Datum::ARRAY) { - return util::get>(this->value)->length; - } else if (this->kind() == Datum::CHUNKED_ARRAY) { - return util::get>(this->value)->length(); - } else if (this->kind() == Datum::SCALAR) { - return 1; - } - return kUnknownLength; + void MoveInto(InputType&& other) { + this->kind_ = other.kind_; + this->shape_ = other.shape_; + this->type_ = std::move(other.type_); + this->type_id_ = other.type_id_; } - /// \brief The array chunks of the variant, if any - /// - /// \return empty if not arraylike - ArrayVector chunks() const { - if (!this->is_arraylike()) { - return {}; - } - if (this->is_array()) { - return {this->make_array()}; - } - return this->chunked_array()->chunks(); + Kind kind_; + + ValueDescr::Shape shape_; + + // For EXACT_TYPE ArgKind + std::shared_ptr type_; + + // For SAME_TYPE_ID ArgKind + Type::type type_id_; +}; + +/// \brief 
Container to capture both exact and input-dependent output types +/// +/// The value shape returned by Resolve will be determined by broadcasting the +/// shapes of the input arguments, otherwise this is handled by the +/// user-defined resolver function +/// +/// * Any ARRAY shape -> output shape is ARRAY +/// * All SCALAR shapes -> output shape is SCALAR +class ARROW_EXPORT OutputType { + public: + /// \brief An enum indicating whether the value type is an invariant fixed + /// value or one that's computed by a kernel-defined resolver function + enum ResolveKind { FIXED, COMPUTED }; + + /// Type resolution function. Given input types and shapes, return output + /// type and shape. This function SHOULD _not_ be used to check for arity, + /// that SHOULD be performed one or more layers above. + using Resolver = std::function(const std::vector&)>; + + OutputType(std::shared_ptr type) // NOLINT implicit construction + : kind_(FIXED), type_(std::move(type)) {} + + /// For outputting a particular type and shape + OutputType(ValueDescr descr); // NOLINT implicit construction + + explicit OutputType(Resolver resolver) : kind_(COMPUTED), resolver_(resolver) {} + + OutputType(const OutputType& other) { + this->kind_ = other.kind_; + this->type_ = other.type_; + this->resolver_ = other.resolver_; } - bool Equals(const Datum& other) const { - if (this->kind() != other.kind()) return false; - - switch (this->kind()) { - case Datum::NONE: - return true; - case Datum::SCALAR: - return internal::SharedPtrEquals(this->scalar(), other.scalar()); - case Datum::ARRAY: - return internal::SharedPtrEquals(this->make_array(), other.make_array()); - case Datum::CHUNKED_ARRAY: - return internal::SharedPtrEquals(this->chunked_array(), other.chunked_array()); - case Datum::RECORD_BATCH: - return internal::SharedPtrEquals(this->record_batch(), other.record_batch()); - case Datum::TABLE: - return internal::SharedPtrEquals(this->table(), other.table()); - case Datum::COLLECTION: - return 
CollectionEquals(this->collection(), other.collection()); - default: - return false; - } + OutputType(OutputType&& other) { + this->kind_ = other.kind_; + this->type_ = std::move(other.type_); + this->resolver_ = other.resolver_; } + + /// \brief Return the shape and type of the expected output value of the + /// kernel given the value descriptors (shapes and types) + Result Resolve(const std::vector& args) const; + + /// \brief The value type for the FIXED kind rule + const std::shared_ptr& type() const; + + /// \brief For use with COMPUTED resolution strategy, the output type depends + /// on the input type. It may be more convenient to invoke this with + /// OutputType::Resolve returned from this method + const Resolver& resolver() const; + + /// \brief Render a human-readable string representation + std::string ToString() const; + + /// \brief Return the kind of type resolution of this output type, whether + /// fixed/invariant or computed by a "user"-defined resolver + ResolveKind kind() const { return kind_; } + + private: + ResolveKind kind_; + + // For FIXED resolution + std::shared_ptr type_; + + // For COMPUTED resolution + Resolver resolver_; }; -/// \class UnaryKernel -/// \brief An array-valued function of a single input argument. +/// \brief Holds the input types and output type of the kernel /// -/// Note to implementors: Try to avoid making kernels that allocate memory if -/// the output size is a deterministic function of the Input Datum's metadata. -/// Instead separate the logic of the kernel and allocations necessary into -/// two different kernels. Some reusable kernels that allocate buffers -/// and delegate computation to another kernel are available in util-internal.h. -class ARROW_EXPORT UnaryKernel : public OpKernel { +/// Varargs functions should pass a single input type to be used to validate +/// the the input types of a function invocation +class ARROW_EXPORT KernelSignature { public: - /// \brief Executes the kernel. 
- /// - /// \param[in] ctx The function context for the kernel - /// \param[in] input The kernel input data - /// \param[out] out The output of the function. Each implementation of this - /// function might assume different things about the existing contents of out - /// (e.g. which buffers are preallocated). In the future it is expected that - /// there will be a more generic mechanism for understanding the necessary - /// contracts. - virtual Status Call(FunctionContext* ctx, const Datum& input, Datum* out) = 0; + KernelSignature(std::vector in_types, OutputType out_type, + bool is_varargs = false); + + /// \brief Convenience ctor since make_shared can be awkward + static std::shared_ptr Make(std::vector in_types, + OutputType out_type, + bool is_varargs = false); + + /// \brief Return true if the signature if compatible with the list of input + /// value descriptors + bool MatchesInputs(const std::vector& descriptors) const; + + /// \brief Returns true if the input types of each signature are + /// equal. 
Well-formed functions should have a deterministic output type + /// given input types, but currently it is the responsibility of the + /// developer to ensure this + bool Equals(const KernelSignature& other) const; + + bool operator==(const KernelSignature& other) const { return this->Equals(other); } + + bool operator!=(const KernelSignature& other) const { return !(*this == other); } + + /// \brief Compute a hash code for the signature + int64_t Hash() const; + + const std::vector& in_types() const { return in_types_; } + + const OutputType& out_type() const { return out_type_; } + + /// \brief Render a human-readable string representation + std::string ToString() const; + + bool is_varargs() const { return is_varargs_; } + + private: + std::vector in_types_; + OutputType out_type_; + bool is_varargs_; + + // For caching the hash code after it's computed the first time + mutable int64_t hash_code_; }; -/// \class BinaryKernel -/// \brief An array-valued function of a two input arguments -class ARROW_EXPORT BinaryKernel : public OpKernel { - public: - virtual Status Call(FunctionContext* ctx, const Datum& left, const Datum& right, - Datum* out) = 0; +struct SimdLevel { + enum type { NONE, SSE4_2, AVX, AVX2, AVX512, NEON }; }; -// TODO doxygen 1.8.16 does not like the following code -///@cond INTERNAL +struct NullHandling { + enum type { + /// Compute the output validity bitmap by intersecting the validity bitmaps + /// of the arguments. 
Kernel does not do anything with the bitmap + INTERSECTION, -static inline bool CollectionEquals(const std::vector& left, - const std::vector& right) { - if (left.size() != right.size()) { - return false; - } + /// Kernel expects a pre-allocated buffer to write the result bitmap into + COMPUTED_PREALLOCATE, - for (size_t i = 0; i < left.size(); i++) { - if (!left[i].Equals(right[i])) { - return false; - } - } - return true; -} + /// Kernel allocates and populates the validity bitmap of the output + COMPUTED_NO_PREALLOCATE, + + /// Output is never null + OUTPUT_NOT_NULL + }; +}; + +struct MemAllocation { + enum type { + // For data types that support pre-allocation (fixed-type), the kernel + // expects to be provided pre-allocated memory to write + // into. Non-fixed-width must always allocate their own memory but perhaps + // not their validity bitmaps. The allocation made for the same length as + // the execution batch, so vector kernels yielding differently sized output + // should not use this + PREALLOCATE, + + // The kernel does its own memory allocation + NO_PREALLOCATE + }; +}; + +struct Kernel; + +using KernelInit = std::function( + KernelContext*, const Kernel&, const FunctionOptions*)>; + +/// \brief Base type for kernels. Contains the function signature and +/// optionally the state initialization function, along with some common +/// attributes +struct Kernel { + Kernel() {} + + Kernel(std::shared_ptr sig, KernelInit init) + : signature(std::move(sig)), init(init) {} -///@endcond + Kernel(std::vector in_types, OutputType out_type, KernelInit init) + : Kernel(KernelSignature::Make(std::move(in_types), out_type), init) {} + + std::shared_ptr signature; + + /// \brief Create a new KernelState for invocations of this kernel, e.g. to + /// set up any options or state relevant for execution. May be nullptr + KernelInit init; + + // Does execution benefit from parallelization (splitting large chunks into + // smaller chunks and using multiple threads). 
Some vector kernels may + // require single-threaded execution. + bool parallelizable = true; + + SimdLevel::type simd_level = SimdLevel::NONE; +}; + +/// \brief Descriptor to hold signature and execution function implementations +/// for a particular kernel +struct ArrayKernel : public Kernel { + ArrayKernel() {} + + ArrayKernel(std::shared_ptr sig, ArrayKernelExec exec, + KernelInit init = NULLPTR) + : Kernel(std::move(sig), init), exec(exec) {} + + ArrayKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR) + : Kernel(std::move(in_types), std::move(out_type), init), exec(exec) {} + + /// \brief Perform a single invocation of this kernel. In general, this + /// function must + ArrayKernelExec exec; + + /// \brief Writing execution results into larger contiguous allocations + /// requires that the kernel be able to write into sliced output + /// ArrayData*. Some kernel implementations may not be able to do this, so + /// setting this to false disables this functionality + bool can_write_into_slices = true; +}; + +struct ScalarKernel : public ArrayKernel { + using ArrayKernel::ArrayKernel; + + // For scalar functions preallocated data and intersecting arg validity + // bitmaps is a reasonable default + NullHandling::type null_handling = NullHandling::INTERSECTION; + MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; +}; + +// Finalize returns Datum to permit multiple return values +using VectorFinalize = std::function*)>; + +struct VectorKernel : public ArrayKernel { + VectorKernel() {} + + VectorKernel(std::shared_ptr sig, ArrayKernelExec exec) + : ArrayKernel(std::move(sig), exec) {} + + VectorKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR, VectorFinalize finalize = NULLPTR) + : ArrayKernel(std::move(in_types), out_type, exec, init), finalize(finalize) {} + + VectorKernel(std::shared_ptr sig, ArrayKernelExec exec, + KernelInit init = NULLPTR, 
VectorFinalize finalize = NULLPTR) + : ArrayKernel(std::move(sig), exec, init), finalize(finalize) {} + + VectorFinalize finalize; + + // Since vector kernels generally are implemented rather differently from + // scalar/elementwise kernels (and they may not even yield arrays of the same + // size), so we make the developer opt-in to any memory preallocation rather + // than having to turn it off. + NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE; +}; + +using ScalarAggregateConsume = std::function; + +using ScalarAggregateMerge = + std::function; + +// Finalize returns Datum to permit multiple return values +using ScalarAggregateFinalize = std::function; + +struct ScalarAggregateKernel : public Kernel { + ScalarAggregateKernel() {} + + ScalarAggregateKernel(std::shared_ptr sig, KernelInit init, + ScalarAggregateConsume consume, ScalarAggregateMerge merge, + ScalarAggregateFinalize finalize) + : Kernel(std::move(sig), init), + consume(consume), + merge(merge), + finalize(finalize) {} + + ScalarAggregateKernel(std::vector in_types, OutputType out_type, + KernelInit init, ScalarAggregateConsume consume, + ScalarAggregateMerge merge, ScalarAggregateFinalize finalize) + : ScalarAggregateKernel(KernelSignature::Make(std::move(in_types), out_type), init, + consume, merge, finalize) {} + + ScalarAggregateConsume consume; + ScalarAggregateMerge merge; + ScalarAggregateFinalize finalize; +}; } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc new file mode 100644 index 00000000000..b562da95815 --- /dev/null +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -0,0 +1,430 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include + +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/util/key_value_metadata.h" + +namespace arrow { +namespace compute { + +// ---------------------------------------------------------------------- +// InputType + +TEST(InputType, AnyTypeConstructor) { + // Check the ANY_TYPE ctors + InputType ty; + ASSERT_EQ(InputType::ANY_TYPE, ty.kind()); + ASSERT_EQ(ValueDescr::ANY, ty.shape()); + + ty = InputType(ValueDescr::SCALAR); + ASSERT_EQ(ValueDescr::SCALAR, ty.shape()); + + ty = InputType(ValueDescr::ARRAY); + ASSERT_EQ(ValueDescr::ARRAY, ty.shape()); +} + +TEST(InputType, Constructors) { + // Exact type constructor + InputType ty1(int8()); + ASSERT_EQ(InputType::EXACT_TYPE, ty1.kind()); + ASSERT_EQ(ValueDescr::ANY, ty1.shape()); + AssertTypeEqual(*int8(), *ty1.type()); + + InputType ty1_implicit = int8(); + ASSERT_TRUE(ty1.Equals(ty1_implicit)); + + InputType ty1_array(int8(), ValueDescr::ARRAY); + ASSERT_EQ(ValueDescr::ARRAY, ty1_array.shape()); + + InputType ty1_scalar(int8(), ValueDescr::SCALAR); + ASSERT_EQ(ValueDescr::SCALAR, ty1_scalar.shape()); + + // Same type id constructor + InputType ty2 = Type::DECIMAL; + ASSERT_EQ(InputType::SAME_TYPE_ID, ty2.kind()); + + InputType ty2_array(Type::DECIMAL, ValueDescr::ARRAY); + ASSERT_EQ(ValueDescr::ARRAY, 
ty2_array.shape()); + + InputType ty2_scalar(Type::DECIMAL, ValueDescr::SCALAR); + ASSERT_EQ(ValueDescr::SCALAR, ty2_scalar.shape()); + + // Implicit construction in a vector + std::vector types = {int8(), Type::DECIMAL}; + ASSERT_TRUE(types[0].Equals(ty1)); + ASSERT_TRUE(types[1].Equals(ty2)); + + // Copy constructor + InputType ty3 = ty1; + InputType ty4 = ty2; + ASSERT_TRUE(ty3.Equals(ty1)); + ASSERT_TRUE(ty4.Equals(ty2)); + + // Move constructor + InputType ty5 = std::move(ty3); + InputType ty6 = std::move(ty4); + ASSERT_TRUE(ty5.Equals(ty1)); + ASSERT_TRUE(ty6.Equals(ty2)); + + // ToString + ASSERT_EQ("any[int8]", ty1.ToString()); + ASSERT_EQ("array[int8]", ty1_array.ToString()); + ASSERT_EQ("scalar[int8]", ty1_scalar.ToString()); + + ASSERT_EQ("any[decimal*]", ty2.ToString()); + ASSERT_EQ("array[decimal*]", ty2_array.ToString()); + ASSERT_EQ("scalar[decimal*]", ty2_scalar.ToString()); +} + +TEST(InputType, Equals) { + InputType t1 = int8(); + InputType t2 = int8(); + InputType t3(int8(), ValueDescr::ARRAY); + InputType t3_i32(int32(), ValueDescr::ARRAY); + InputType t3_scalar(int8(), ValueDescr::SCALAR); + InputType t4(int8(), ValueDescr::ARRAY); + InputType t4_i32(int32(), ValueDescr::ARRAY); + + InputType t5 = Type::DECIMAL; + InputType t6 = Type::DECIMAL; + InputType t7(Type::DECIMAL, ValueDescr::SCALAR); + InputType t7_i32(Type::INT32, ValueDescr::SCALAR); + InputType t8(Type::DECIMAL, ValueDescr::SCALAR); + InputType t8_i32(Type::INT32, ValueDescr::SCALAR); + + ASSERT_TRUE(t1.Equals(t2)); + ASSERT_EQ(t1, t2); + + // ANY vs SCALAR + ASSERT_NE(t1, t3); + + ASSERT_EQ(t3, t4); + + // both ARRAY, but different type + ASSERT_NE(t3, t3_i32); + + // ARRAY vs SCALAR + ASSERT_NE(t3, t3_scalar); + + ASSERT_EQ(t3_i32, t4_i32); + + ASSERT_FALSE(t1.Equals(t5)); + ASSERT_NE(t1, t5); + + ASSERT_EQ(t5, t5); + ASSERT_EQ(t5, t6); + ASSERT_NE(t5, t7); + ASSERT_EQ(t7, t8); + ASSERT_EQ(t7, t8); + ASSERT_NE(t7, t7_i32); + ASSERT_EQ(t7_i32, t8_i32); + + // NOTE: For the time 
being, we treat int32() and Type::INT32 as being + // different. This could obviously be fixed later to make these equivalent + ASSERT_NE(InputType(int8()), InputType(Type::INT32)); + + // Check that field metadata excluded from equality checks + InputType t9 = list( + field("item", utf8(), /*nullable=*/true, key_value_metadata({"foo"}, {"bar"}))); + InputType t10 = list(field("item", utf8())); + ASSERT_TRUE(t9.Equals(t10)); +} + +TEST(InputType, Hash) { + InputType t0; + InputType t0_scalar(ValueDescr::SCALAR); + InputType t0_array(ValueDescr::ARRAY); + + InputType t1 = int8(); + InputType t2 = Type::DECIMAL; + + // These checks try to determine first of all whether Hash always returns the + // same value, and whether the elements of the type are all incorporated into + // the Hash + ASSERT_EQ(t0.Hash(), t0.Hash()); + ASSERT_NE(t0.Hash(), t0_scalar.Hash()); + ASSERT_NE(t0.Hash(), t0_array.Hash()); + ASSERT_NE(t0_scalar.Hash(), t0_array.Hash()); + + ASSERT_EQ(t1.Hash(), t1.Hash()); + ASSERT_EQ(t2.Hash(), t2.Hash()); + + ASSERT_NE(t0.Hash(), t1.Hash()); + ASSERT_NE(t0.Hash(), t2.Hash()); + ASSERT_NE(t1.Hash(), t2.Hash()); +} + +TEST(InputType, Matches) { + InputType ty1 = int8(); + + ASSERT_TRUE(ty1.Matches(ValueDescr::Scalar(int8()))); + ASSERT_TRUE(ty1.Matches(ValueDescr::Array(int8()))); + ASSERT_TRUE(ty1.Matches(ValueDescr::Any(int8()))); + ASSERT_FALSE(ty1.Matches(ValueDescr::Any(int16()))); + + InputType ty2 = Type::DECIMAL; + ASSERT_TRUE(ty2.Matches(ValueDescr::Scalar(decimal(12, 2)))); + ASSERT_TRUE(ty2.Matches(ValueDescr::Array(decimal(12, 2)))); + ASSERT_FALSE(ty2.Matches(ValueDescr::Any(float64()))); + + InputType ty3(int64(), ValueDescr::SCALAR); + ASSERT_FALSE(ty3.Matches(ValueDescr::Array(int64()))); + ASSERT_TRUE(ty3.Matches(ValueDescr::Scalar(int64()))); + ASSERT_FALSE(ty3.Matches(ValueDescr::Scalar(int32()))); + ASSERT_FALSE(ty3.Matches(ValueDescr::Any(int64()))); +} + +// ---------------------------------------------------------------------- +// 
OutputType + +TEST(OutputType, Constructors) { + OutputType ty1 = int8(); + ASSERT_EQ(OutputType::FIXED, ty1.kind()); + AssertTypeEqual(*int8(), *ty1.type()); + + auto DummyResolver = [](const std::vector& args) { + return ValueDescr(int32(), GetBroadcastShape(args)); + }; + OutputType ty2(DummyResolver); + ASSERT_EQ(OutputType::COMPUTED, ty2.kind()); + + ASSERT_OK_AND_ASSIGN(ValueDescr out_descr2, ty2.Resolve({})); + ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr2); + + // Copy constructor + OutputType ty3 = ty1; + ASSERT_EQ(OutputType::FIXED, ty3.kind()); + AssertTypeEqual(*ty1.type(), *ty3.type()); + + OutputType ty4 = ty2; + ASSERT_EQ(OutputType::COMPUTED, ty4.kind()); + ASSERT_OK_AND_ASSIGN(ValueDescr out_descr4, ty4.Resolve({})); + ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr4); + + // Move constructor + OutputType ty5 = std::move(ty1); + ASSERT_EQ(OutputType::FIXED, ty5.kind()); + AssertTypeEqual(*int8(), *ty5.type()); + + OutputType ty6 = std::move(ty4); + ASSERT_EQ(OutputType::COMPUTED, ty6.kind()); + ASSERT_OK_AND_ASSIGN(ValueDescr out_descr6, ty6.Resolve({})); + ASSERT_EQ(ValueDescr::Scalar(int32()), out_descr6); + + // ToString + + // ty1 was copied to ty3 + ASSERT_EQ("int8", ty3.ToString()); + ASSERT_EQ("computed", ty2.ToString()); +} + +TEST(OutputType, Resolve) { + // Check shape promotion rules for FIXED kind + OutputType ty1(int32()); + + ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve({})); + ASSERT_EQ(ValueDescr::Scalar(int32()), descr); + + ASSERT_OK_AND_ASSIGN(descr, ty1.Resolve({ValueDescr(int8(), ValueDescr::SCALAR)})); + ASSERT_EQ(ValueDescr::Scalar(int32()), descr); + + ASSERT_OK_AND_ASSIGN(descr, ty1.Resolve({ValueDescr(int8(), ValueDescr::SCALAR), + ValueDescr(int8(), ValueDescr::ARRAY)})); + ASSERT_EQ(ValueDescr::Array(int32()), descr); + + OutputType ty2([](const std::vector& args) -> Result { + return ValueDescr(args[0].type, GetBroadcastShape(args)); + }); + + ASSERT_OK_AND_ASSIGN(descr, 
ty2.Resolve({ValueDescr::Array(utf8())})); + ASSERT_EQ(ValueDescr::Array(utf8()), descr); + + // Type resolver that returns an error + OutputType ty3([](const std::vector& args) -> Result { + // NB: checking the value types versus the function arity should be + // validated elsewhere, so this is just for illustration purposes + if (args.size() == 0) { + return Status::Invalid("Need at least one argument"); + } + return ValueDescr(args[0]); + }); + ASSERT_RAISES(Invalid, ty3.Resolve({})); +} + +TEST(OutputType, ResolveDescr) { + ValueDescr d1 = ValueDescr::Scalar(int32()); + ValueDescr d2 = ValueDescr::Array(int32()); + + OutputType ty1(d1); + OutputType ty2(d2); + + { + ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve({})); + ASSERT_EQ(d1, descr); + } + + { + ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty2.Resolve({})); + ASSERT_EQ(d2, descr); + } +} + +// ---------------------------------------------------------------------- +// KernelSignature + +TEST(KernelSignature, Basics) { + // (any[int8], scalar[decimal]) -> utf8 + std::vector in_types({int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)}); + OutputType out_type(utf8()); + + KernelSignature sig(in_types, out_type); + ASSERT_EQ(2, sig.in_types().size()); + ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8())); + ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Scalar(int8()))); + ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Array(int8()))); + + ASSERT_TRUE(sig.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2)))); + ASSERT_FALSE(sig.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2)))); +} + +TEST(KernelSignature, Equals) { + KernelSignature sig1({}, utf8()); + KernelSignature sig1_copy({}, utf8()); + KernelSignature sig2({int8()}, utf8()); + + // Output type doesn't matter (for now) + KernelSignature sig3({int8()}, int32()); + + KernelSignature sig4({int8(), int16()}, utf8()); + KernelSignature sig4_copy({int8(), int16()}, utf8()); + KernelSignature sig5({int8(), int16(), int32()}, utf8()); + 
+ // Differ in shape + KernelSignature sig6({ValueDescr::Scalar(int8())}, utf8()); + KernelSignature sig7({ValueDescr::Array(int8())}, utf8()); + + ASSERT_EQ(sig1, sig1); + + ASSERT_EQ(sig2, sig3); + ASSERT_NE(sig3, sig4); + + // Different sig objects, but same sig + ASSERT_EQ(sig1, sig1_copy); + ASSERT_EQ(sig4, sig4_copy); + + // Match first 2 args, but not third + ASSERT_NE(sig4, sig5); + + ASSERT_NE(sig6, sig7); +} + +TEST(KernelSignature, VarargsEquals) { + KernelSignature sig1({int8()}, utf8(), /*is_varargs=*/true); + KernelSignature sig2({int8()}, utf8(), /*is_varargs=*/true); + KernelSignature sig3({int8()}, utf8()); + + ASSERT_EQ(sig1, sig2); + ASSERT_NE(sig2, sig3); +} + +TEST(KernelSignature, Hash) { + // Some basic tests to ensure that the hashes are deterministic and that all + // input arguments are incorporated + KernelSignature sig1({}, utf8()); + KernelSignature sig2({int8()}, utf8()); + KernelSignature sig3({int8(), int32()}, utf8()); + + ASSERT_EQ(sig1.Hash(), sig1.Hash()); + ASSERT_EQ(sig2.Hash(), sig2.Hash()); + ASSERT_NE(sig1.Hash(), sig2.Hash()); + ASSERT_NE(sig2.Hash(), sig3.Hash()); +} + +TEST(KernelSignature, MatchesInputs) { + // () -> boolean + KernelSignature sig1({}, boolean()); + + ASSERT_TRUE(sig1.MatchesInputs({})); + ASSERT_FALSE(sig1.MatchesInputs({int8()})); + + // (any[int8], any[decimal]) -> boolean + KernelSignature sig2({int8(), Type::DECIMAL}, boolean()); + + ASSERT_FALSE(sig2.MatchesInputs({})); + ASSERT_FALSE(sig2.MatchesInputs({int8()})); + ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)})); + ASSERT_TRUE(sig2.MatchesInputs( + {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))})); + ASSERT_TRUE( + sig2.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))})); + + // (scalar[int8], array[int32]) -> boolean + KernelSignature sig3({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}, + boolean()); + + ASSERT_FALSE(sig3.MatchesInputs({})); + + // Unqualified, these are ANY type and 
do not match because the kernel + // requires a scalar and an array + ASSERT_FALSE(sig3.MatchesInputs({int8(), int32()})); + ASSERT_TRUE( + sig3.MatchesInputs({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())})); + ASSERT_FALSE( + sig3.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(int32())})); +} + +TEST(KernelSignature, VarargsMatchesInputs) { + KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); + + std::vector args = {int8()}; + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(ValueDescr::Scalar(int8())); + args.push_back(ValueDescr::Array(int8())); + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(int32()); + ASSERT_FALSE(sig.MatchesInputs(args)); +} + +TEST(KernelSignature, ToString) { + std::vector in_types = {InputType(int8(), ValueDescr::SCALAR), + InputType(Type::DECIMAL, ValueDescr::ARRAY), + InputType(utf8())}; + KernelSignature sig(in_types, utf8()); + ASSERT_EQ("(scalar[int8], array[decimal*], any[string]) -> string", sig.ToString()); + + OutputType out_type( + [](const std::vector& args) { return Status::Invalid("NYI"); }); + KernelSignature sig2({int8(), Type::DECIMAL}, out_type); + ASSERT_EQ("(any[int8], any[decimal*]) -> computed", sig2.ToString()); +} + +TEST(KernelSignature, VarargsToString) { + KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); + ASSERT_EQ("varargs[any[int8]] -> string", sig.ToString()); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 12ad4d3a958..b230621ad53 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -15,37 +15,49 @@ # specific language governing permissions and limitations # under the License. 
-arrow_install_all_headers("arrow/compute/kernels") - -add_arrow_compute_test(boolean_test) -add_arrow_compute_test(cast_test) -add_arrow_compute_test(hash_test) -add_arrow_compute_test(isin_test) -add_arrow_compute_test(match_test) -add_arrow_compute_test(sort_to_indices_test) -add_arrow_compute_test(nth_to_indices_test) -add_arrow_compute_test(util_internal_test) -add_arrow_compute_test(add_test) +# ---------------------------------------------------------------------- +# Scalar kernels -# Aggregates -add_arrow_compute_test(aggregate_test) +add_arrow_compute_test(scalar_test + SOURCES + scalar_arithmetic_test.cc + scalar_boolean_test.cc + scalar_compare_test.cc + scalar_set_lookup_test.cc) -# Comparison -add_arrow_compute_test(compare_test) +# add_arrow_compute_test(cast_test) -# Selection -add_arrow_compute_test(take_test) -add_arrow_compute_test(filter_test) +add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute") -add_arrow_benchmark(sort_to_indices_benchmark PREFIX "arrow-compute") -add_arrow_benchmark(nth_to_indices_benchmark PREFIX "arrow-compute") +# ---------------------------------------------------------------------- +# Vector kernels -# Aggregates -add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute") +add_arrow_compute_test(vector_test + SOURCES + vector_partition_test.cc) + +# add_arrow_compute_test(hash_test) + +# add_arrow_benchmark(hash_benchmark PREFIX "arrow-compute") + +# Single-array sorting + +# add_arrow_compute_test(sort_to_indices_test) +# add_arrow_benchmark(sort_to_indices_benchmark PREFIX "arrow-compute") +# add_arrow_benchmark(nth_to_indices_benchmark PREFIX "arrow-compute") -# Comparison -add_arrow_benchmark(compare_benchmark PREFIX "arrow-compute") +# Array value selection -# Selection -add_arrow_benchmark(filter_benchmark PREFIX "arrow-compute") -add_arrow_benchmark(take_benchmark PREFIX "arrow-compute") +# add_arrow_compute_test(filter_test) +# add_arrow_compute_test(take_test) + +# 
add_arrow_benchmark(filter_benchmark PREFIX "arrow-compute")a +# add_arrow_benchmark(take_benchmark PREFIX "arrow-compute") + +# ---------------------------------------------------------------------- +# Aggregate kernels + +# Aggregates + +add_arrow_compute_test(aggregate_test) +# add_arrow_benchmark(aggregate_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/add.cc b/cpp/src/arrow/compute/kernels/add.cc deleted file mode 100644 index 19eb153b5cd..00000000000 --- a/cpp/src/arrow/compute/kernels/add.cc +++ /dev/null @@ -1,131 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#include "arrow/compute/kernels/add.h" -#include "arrow/builder.h" -#include "arrow/compute/context.h" -#include "arrow/type_traits.h" - -namespace arrow { -namespace compute { - -template -class AddKernelImpl : public AddKernel { - private: - using ArrayType = typename TypeTraits::ArrayType; - std::shared_ptr result_type_; - - Status Add(FunctionContext* ctx, const std::shared_ptr& lhs, - const std::shared_ptr& rhs, std::shared_ptr* result) { - NumericBuilder builder; - RETURN_NOT_OK(builder.Reserve(lhs->length())); - for (int i = 0; i < lhs->length(); i++) { - if (lhs->IsNull(i) || rhs->IsNull(i)) { - builder.UnsafeAppendNull(); - } else { - builder.UnsafeAppend(lhs->Value(i) + rhs->Value(i)); - } - } - return builder.Finish(result); - } - - public: - explicit AddKernelImpl(std::shared_ptr result_type) - : result_type_(result_type) {} - - Status Call(FunctionContext* ctx, const Datum& lhs, const Datum& rhs, - Datum* out) override { - if (!lhs.is_array() || !rhs.is_array()) { - return Status::Invalid("AddKernel expects array values"); - } - if (lhs.length() != rhs.length()) { - return Status::Invalid("AddKernel expects arrays with the same length"); - } - auto lhs_array = lhs.make_array(); - auto rhs_array = rhs.make_array(); - std::shared_ptr result; - RETURN_NOT_OK(this->Add(ctx, lhs_array, rhs_array, &result)); - *out = result; - return Status::OK(); - } - - std::shared_ptr out_type() const override { return result_type_; } - - Status Add(FunctionContext* ctx, const std::shared_ptr& lhs, - const std::shared_ptr& rhs, std::shared_ptr* result) override { - auto lhs_array = std::static_pointer_cast(lhs); - auto rhs_array = std::static_pointer_cast(rhs); - return Add(ctx, lhs_array, rhs_array, result); - } -}; - -Status AddKernel::Make(const std::shared_ptr& value_type, - std::unique_ptr* out) { - AddKernel* kernel; - switch (value_type->id()) { - case Type::UINT8: - kernel = new AddKernelImpl(value_type); - break; - case Type::INT8: - kernel = new 
AddKernelImpl(value_type); - break; - case Type::UINT16: - kernel = new AddKernelImpl(value_type); - break; - case Type::INT16: - kernel = new AddKernelImpl(value_type); - break; - case Type::UINT32: - kernel = new AddKernelImpl(value_type); - break; - case Type::INT32: - kernel = new AddKernelImpl(value_type); - break; - case Type::UINT64: - kernel = new AddKernelImpl(value_type); - break; - case Type::INT64: - kernel = new AddKernelImpl(value_type); - break; - case Type::FLOAT: - kernel = new AddKernelImpl(value_type); - break; - case Type::DOUBLE: - kernel = new AddKernelImpl(value_type); - break; - default: - return Status::NotImplemented("Arithmetic operations on ", *value_type, " arrays"); - } - out->reset(kernel); - return Status::OK(); -} - -Status Add(FunctionContext* ctx, const Array& lhs, const Array& rhs, - std::shared_ptr* result) { - Datum result_datum; - std::unique_ptr kernel; - ARROW_RETURN_IF( - !lhs.type()->Equals(rhs.type()), - Status::Invalid("Array types should be equal to use arithmetic kernels")); - RETURN_NOT_OK(AddKernel::Make(lhs.type(), &kernel)); - RETURN_NOT_OK(kernel->Call(ctx, Datum(lhs.data()), Datum(rhs.data()), &result_datum)); - *result = result_datum.make_array(); - return Status::OK(); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/add.h b/cpp/src/arrow/compute/kernels/add.h deleted file mode 100644 index 19991aa4473..00000000000 --- a/cpp/src/arrow/compute/kernels/add.h +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include - -#include "arrow/compute/kernel.h" -#include "arrow/status.h" -#include "arrow/util/visibility.h" - -namespace arrow { - -class Array; - -namespace compute { - -class FunctionContext; - -/// \brief Summarizes two arrays. -/// -/// Summarizes two arrays with the same length. -/// The output is an array with same length and type as input. -/// Types of both input arrays should be equal -/// -/// For example given lhs = [1, null, 3], rhs = [4, 5, 6], the output -/// will be [5, null, 7] -/// -/// \param[in] ctx the FunctionContext -/// \param[in] lhs the first array -/// \param[in] rhs the second array -/// \param[out] result the sum of first and second arrays - -ARROW_EXPORT -Status Add(FunctionContext* ctx, const Array& lhs, const Array& rhs, - std::shared_ptr* result); - -/// \brief BinaryKernel implementing Add operation -class ARROW_EXPORT AddKernel : public BinaryKernel { - public: - /// \brief BinaryKernel interface - /// - /// delegates to subclasses via Add() - Status Call(FunctionContext* ctx, const Datum& lhs, const Datum& rhs, - Datum* out) override = 0; - - /// \brief output type of this kernel - std::shared_ptr out_type() const override = 0; - - /// \brief single-array implementation - virtual Status Add(FunctionContext* ctx, const std::shared_ptr& lhs, - const std::shared_ptr& rhs, - std::shared_ptr* result) = 0; - - /// \brief factory for Add - /// - /// \param[in] value_type constructed AddKernel - /// \param[out] out created kernel - static Status Make(const std::shared_ptr& value_type, - 
std::unique_ptr* out); -}; -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate.cc b/cpp/src/arrow/compute/kernels/aggregate.cc deleted file mode 100644 index 90337588615..00000000000 --- a/cpp/src/arrow/compute/kernels/aggregate.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include - -#include "arrow/compute/context.h" -#include "arrow/compute/kernels/aggregate.h" - -namespace arrow { -namespace compute { - -// Helper class that properly invokes destructor when state goes out of scope. 
-class ManagedAggregateState { - public: - ManagedAggregateState(std::shared_ptr& desc, - std::shared_ptr&& buffer) - : desc_(desc), state_(buffer) { - desc_->New(state_->mutable_data()); - } - - ~ManagedAggregateState() { desc_->Delete(state_->mutable_data()); } - - void* mutable_data() { return state_->mutable_data(); } - - static std::shared_ptr Make( - std::shared_ptr& desc, MemoryPool* pool) { - auto maybe_buf = AllocateBuffer(desc->Size(), pool); - if (!maybe_buf.ok()) { - return nullptr; - } - return std::make_shared(desc, *std::move(maybe_buf)); - } - - private: - std::shared_ptr desc_; - std::shared_ptr state_; -}; - -Status AggregateUnaryKernel::Call(FunctionContext* ctx, const Datum& input, Datum* out) { - if (!input.is_arraylike()) { - return Status::Invalid("AggregateKernel expects Array or ChunkedArray datum"); - } - auto state = ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool()); - if (!state) { - return Status::OutOfMemory("AggregateState allocation failed"); - } - - if (input.is_array()) { - auto array = input.make_array(); - RETURN_NOT_OK(aggregate_function_->Consume(*array, state->mutable_data())); - } else { - auto chunked_array = input.chunked_array(); - for (int i = 0; i < chunked_array->num_chunks(); i++) { - auto tmp_state = - ManagedAggregateState::Make(aggregate_function_, ctx->memory_pool()); - if (!tmp_state) { - return Status::OutOfMemory("AggregateState allocation failed"); - } - RETURN_NOT_OK(aggregate_function_->Consume(*chunked_array->chunk(i), - tmp_state->mutable_data())); - RETURN_NOT_OK( - aggregate_function_->Merge(tmp_state->mutable_data(), state->mutable_data())); - } - } - - return aggregate_function_->Finalize(state->mutable_data(), out); -} - -std::shared_ptr AggregateUnaryKernel::out_type() const { - return aggregate_function_->out_type(); -} - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate.h b/cpp/src/arrow/compute/kernels/aggregate.h deleted file mode 
100644 index f342e31a0b6..00000000000 --- a/cpp/src/arrow/compute/kernels/aggregate.h +++ /dev/null @@ -1,115 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include - -#include "arrow/compute/kernel.h" - -namespace arrow { - -class Array; -class Status; - -namespace compute { - -class FunctionContext; -struct Datum; - -/// AggregateFunction is an interface for Aggregates -/// -/// An aggregates transforms an array into single result called a state via the -/// Consume method.. State supports the merge operation via the Merge method. -/// State can be sealed into a final result via the Finalize method. -// -/// State ownership is handled by callers, thus the interface exposes 3 methods -/// for the caller to manage memory: -/// - Size -/// - New (placement new constructor invocation) -/// - Delete (state destructor) -/// -/// Design inspired by ClickHouse aggregate functions. -class AggregateFunction { - public: - /// \brief Consume an array into a state. - virtual Status Consume(const Array& input, void* state) const = 0; - - /// \brief Merge states. - virtual Status Merge(const void* src, void* dst) const = 0; - - /// \brief Convert state into a final result. 
- virtual Status Finalize(const void* src, Datum* output) const = 0; - - virtual ~AggregateFunction() {} - - virtual std::shared_ptr out_type() const = 0; - - /// State management methods. - virtual int64_t Size() const = 0; - virtual void New(void* ptr) const = 0; - virtual void Delete(void* ptr) const = 0; -}; - -/// AggregateFunction partial implementation for static type state -template -class AggregateFunctionStaticState : public AggregateFunction { - virtual Status Consume(const Array& input, State* state) const = 0; - virtual Status Merge(const State& src, State* dst) const = 0; - virtual Status Finalize(const State& src, Datum* output) const = 0; - - Status Consume(const Array& input, void* state) const final { - return Consume(input, static_cast(state)); - } - - Status Merge(const void* src, void* dst) const final { - return Merge(*static_cast(src), static_cast(dst)); - } - - /// \brief Convert state into a final result. - Status Finalize(const void* src, Datum* output) const final { - return Finalize(*static_cast(src), output); - } - - int64_t Size() const final { return sizeof(State); } - - void New(void* ptr) const final { - // By using placement-new syntax, the constructor of the State is invoked - // in the memory location defined by the caller. This only supports State - // with a parameter-less constructor. 
- new (ptr) State; - } - - void Delete(void* ptr) const final { static_cast(ptr)->~State(); } -}; - -/// \brief UnaryKernel implemented by an AggregateState -class ARROW_EXPORT AggregateUnaryKernel : public UnaryKernel { - public: - explicit AggregateUnaryKernel(std::shared_ptr& aggregate) - : aggregate_function_(aggregate) {} - - Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override; - - std::shared_ptr out_type() const override; - - private: - std::shared_ptr aggregate_function_; -}; - -} // namespace compute -} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc new file mode 100644 index 00000000000..8fb02f18d9e --- /dev/null +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/common.h" + +namespace arrow { +namespace compute { + +namespace { + +struct ScalarAggregator : public KernelState { + virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0; + virtual void MergeFrom(KernelContext* ctx, const KernelState& src) = 0; + virtual void Finalize(KernelContext* ctx, Datum* out) = 0; +}; + +void AggregateConsume(KernelContext* ctx, const ExecBatch& batch) { + checked_cast(ctx->state())->Consume(ctx, batch); +} + +void AggregateMerge(KernelContext* ctx, const KernelState& src, KernelState* dst) { + checked_cast(dst)->MergeFrom(ctx, src); +} + +void AggregateFinalize(KernelContext* ctx, Datum* out) { + checked_cast(ctx->state())->Finalize(ctx, out); +} + +// ---------------------------------------------------------------------- +// Count implementation + +struct CountImpl : public ScalarAggregator { + explicit CountImpl(CountOptions options) + : options(std::move(options)), non_nulls(0), nulls(0) {} + + void Consume(KernelContext*, const ExecBatch& batch) override { + const ArrayData& input = *batch[0].array(); + const int64_t nulls = input.GetNullCount(); + this->nulls += nulls; + this->non_nulls += input.length - nulls; + } + + void MergeFrom(KernelContext*, const KernelState& src) override { + const auto& other_state = checked_cast(src); + this->non_nulls += other_state.non_nulls; + this->nulls += other_state.nulls; + } + + void Finalize(KernelContext* ctx, Datum* out) override { + const auto& state = checked_cast(*ctx->state()); + switch (state.options.count_mode) { + case CountOptions::COUNT_ALL: + *out = Datum(state.non_nulls); + break; + case CountOptions::COUNT_NULL: + *out = Datum(state.nulls); + break; + default: + ctx->SetStatus(Status::Invalid("Unknown CountOptions encountered")); + break; + } + } + + CountOptions options; + int64_t non_nulls = 0; + int64_t nulls = 0; +}; + +std::unique_ptr CountInit(KernelContext*, const 
Kernel&, + const FunctionOptions* options) { + return std::unique_ptr( + new CountImpl(static_cast(*options))); +} + +// ---------------------------------------------------------------------- +// Sum implementation + +template ::Type> +struct SumState { + using ThisType = SumState; + using T = typename TypeTraits::CType; + using ArrayType = typename TypeTraits::ArrayType; + + // A small number of elements rounded to the next cacheline. This should + // amount to a maximum of 4 cachelines when dealing with 8 bytes elements. + static constexpr int64_t kTinyThreshold = 32; + static_assert(kTinyThreshold >= (2 * CHAR_BIT) + 1, + "ConsumeSparse requires 3 bytes of null bitmap, and 17 is the" + "required minimum number of bits/elements to cover 3 bytes."); + + ThisType operator+(const ThisType& rhs) const { + return ThisType(this->count + rhs.count, this->sum + rhs.sum); + } + + ThisType& operator+=(const ThisType& rhs) { + this->count += rhs.count; + this->sum += rhs.sum; + + return *this; + } + + public: + void Consume(const Array& input) { + const ArrayType& array = static_cast(input); + if (input.null_count() == 0) { + (*this) += ConsumeDense(array); + } else if (input.length() <= kTinyThreshold) { + // In order to simplify ConsumeSparse implementation (requires at least 3 + // bytes of bitmap data), small arrays are handled differently. 
+ (*this) += ConsumeTiny(array); + } else { + (*this) += ConsumeSparse(array); + } + } + + size_t count = 0; + typename SumType::c_type sum = 0; + + private: + ThisType ConsumeDense(const ArrayType& array) const { + ThisType local; + const auto values = array.raw_values(); + const int64_t length = array.length(); + for (int64_t i = 0; i < length; i++) { + local.sum += values[i]; + } + local.count = length; + return local; + } + + ThisType ConsumeTiny(const ArrayType& array) const { + ThisType local; + + internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), + array.length()); + const auto values = array.raw_values(); + for (int64_t i = 0; i < array.length(); i++) { + if (reader.IsSet()) { + local.sum += values[i]; + local.count++; + } + reader.Next(); + } + + return local; + } + + // While this is not branchless, gcc needs this to be in a different function + // for it to generate cmov which ends to be slightly faster than + // multiplication but safe for handling NaN with doubles. + inline T MaskedValue(bool valid, T value) const { return valid ? value : 0; } + + inline ThisType UnrolledSum(uint8_t bits, const T* values) const { + ThisType local; + + if (bits < 0xFF) { + // Some nulls + for (size_t i = 0; i < 8; i++) { + local.sum += MaskedValue(bits & (1U << i), values[i]); + } + local.count += BitUtil::kBytePopcount[bits]; + } else { + // No nulls + for (size_t i = 0; i < 8; i++) { + local.sum += values[i]; + } + local.count += 8; + } + + return local; + } + + ThisType ConsumeSparse(const ArrayType& array) const { + ThisType local; + + // Sliced bitmaps on non-byte positions induce problem with the branchless + // unrolled technique. Thus extra padding is added on both left and right + // side of the slice such that both ends are byte-aligned. The first and + // last bitmap are properly masked to ignore extra values induced by + // padding. + // + // The execution is divided in 3 sections. + // + // 1. Compute the sum of the first masked byte. 
+ // 2. Compute the sum of the middle bytes + // 3. Compute the sum of the last masked byte. + + const int64_t length = array.length(); + const int64_t offset = array.offset(); + + // The number of bytes covering the range, this includes partial bytes. + // This number bounded by `<= (length / 8) + 2`, e.g. a possible extra byte + // on the left, and on the right. + const int64_t covering_bytes = BitUtil::CoveringBytes(offset, length); + DCHECK_GE(covering_bytes, 3); + + // Align values to the first batch of 8 elements. Note that raw_values() is + // already adjusted with the offset, thus we rewind a little to align to + // the closest 8-batch offset. + const auto values = array.raw_values() - (offset % 8); + + // Align bitmap at the first consumable byte. + const auto bitmap = array.null_bitmap_data() + BitUtil::RoundDown(offset, 8) / 8; + + // Consume the first (potentially partial) byte. + const uint8_t first_mask = BitUtil::kTrailingBitmask[offset % 8]; + local += UnrolledSum(bitmap[0] & first_mask, values); + + // Consume the (full) middle bytes. The loop iterates in unit of + // batches of 8 values and 1 byte of bitmap. + for (int64_t i = 1; i < covering_bytes - 1; i++) { + local += UnrolledSum(bitmap[i], &values[i * 8]); + } + + // Consume the last (potentially partial) byte. 
+ const int64_t last_idx = covering_bytes - 1; + const uint8_t last_mask = BitUtil::kPrecedingWrappingBitmask[(offset + length) % 8]; + local += UnrolledSum(bitmap[last_idx] & last_mask, &values[last_idx * 8]); + + return local; + } +}; + +template +struct SumImpl : public ScalarAggregator { + using ArrayType = typename TypeTraits::ArrayType; + using ThisType = SumImpl; + using SumType = typename FindAccumulatorType::Type; + using OutputType = typename TypeTraits::ScalarType; + + void Consume(KernelContext*, const ExecBatch& batch) override { + this->state.Consume(ArrayType(batch[0].array())); + } + + void MergeFrom(KernelContext*, const KernelState& src) override { + const auto& other = checked_cast(src); + this->state += other.state; + } + + void Finalize(KernelContext*, Datum* out) override { + if (state.count == 0) { + out->value = std::make_shared(); + } else { + out->value = MakeScalar(state.sum); + } + } + + SumState state; +}; + +template +struct MeanImpl : public SumImpl { + void Finalize(KernelContext*, Datum* out) override { + const bool is_valid = this->state.count > 0; + const double divisor = static_cast(is_valid ? this->state.count : 1UL); + const double mean = static_cast(this->state.sum) / divisor; + + if (!is_valid) { + out->value = std::make_shared(); + } else { + out->value = std::make_shared(mean); + } + } +}; + +template