diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index 12fda5d58f3..b49527d6498 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -22,12 +22,14 @@ #include #include +#include "arrow/array/concatenate.h" #include "arrow/chunked_array.h" #include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/expression_internal.h" #include "arrow/compute/function_internal.h" +#include "arrow/compute/kernels/set_lookup_internal.h" #include "arrow/compute/util.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" @@ -38,6 +40,7 @@ #include "arrow/util/string.h" #include "arrow/util/value_parsing.h" #include "arrow/util/vector.h" +#include "arrow/visit_array_inline.h" namespace arrow { @@ -1243,72 +1246,6 @@ struct Inequality { /*insert_implicit_casts=*/false, &exec_context); } - /// Simplify an `is_in` call against an inequality guarantee. - /// - /// We avoid the complexity of fully simplifying EQUAL comparisons to true - /// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to - /// potential complications with null matching behavior. This is ok for the - /// predicate pushdown use case because the overall aim is to simplify to an - /// unsatisfiable expression. - /// - /// \pre `is_in_call` is a call to the `is_in` function - /// \return a simplified expression, or nullopt if no simplification occurred - static Result> SimplifyIsIn( - const Inequality& guarantee, const Expression::Call* is_in_call) { - DCHECK_EQ(is_in_call->function_name, "is_in"); - - auto options = checked_pointer_cast(is_in_call->options); - - const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]); - if (!lhs.field_ref()) return std::nullopt; - if (*lhs.field_ref() != guarantee.target) return std::nullopt; - - FilterOptions::NullSelectionBehavior null_selection; - switch (options->null_matching_behavior) { - case SetLookupOptions::MATCH: - null_selection = - guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP; - break; - case SetLookupOptions::SKIP: - null_selection = FilterOptions::DROP; - break; - case SetLookupOptions::EMIT_NULL: - if (guarantee.nullable) return std::nullopt; - null_selection = FilterOptions::DROP; - break; - case SetLookupOptions::INCONCLUSIVE: - if (guarantee.nullable) return std::nullopt; - ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set)); - ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null)); - if (any_null.scalar_as().value) return std::nullopt; - null_selection = FilterOptions::DROP; - break; - } - - std::string func_name = Comparison::GetName(guarantee.cmp); - DCHECK_NE(func_name, "na"); - std::vector args{options->value_set, guarantee.bound}; - ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args)); - FilterOptions filter_options(null_selection); - ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set, - Filter(options->value_set, filter_mask, filter_options)); - - if (simplified_value_set.length() == 0) return literal(false); - if (simplified_value_set.length() == options->value_set.length()) return std::nullopt; - - ExecContext exec_context; - Expression::Call simplified_call; - simplified_call.function_name = "is_in"; - simplified_call.arguments = is_in_call->arguments; - simplified_call.options = std::make_shared( - simplified_value_set, options->null_matching_behavior); - ARROW_ASSIGN_OR_RAISE( - Expression simplified_expr, - BindNonRecursive(std::move(simplified_call), - /*insert_implicit_casts=*/false, &exec_context)); - return simplified_expr; - } - /// \brief Simplify the given expression given this inequality as a guarantee. Result Simplify(Expression expr) { const auto& guarantee = *this; @@ -1325,12 +1262,6 @@ struct Inequality { return call->function_name == "is_valid" ? literal(true) : literal(false); } - if (call->function_name == "is_in") { - ARROW_ASSIGN_OR_RAISE(std::optional result, - SimplifyIsIn(guarantee, call)); - return result.value_or(expr); - } - auto cmp = Comparison::Get(expr); if (!cmp) return expr; @@ -1416,6 +1347,213 @@ Result SimplifyIsValidGuarantee(Expression expr, }); } +/// Simplify an `is_in` value set against a single inequality guarantee. +/// +/// We avoid the complexity of fully simplifying EQUAL comparisons to true +/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to +/// potential complications with null matching behavior. This is ok for the +/// predicate pushdown use case because the overall aim is to simplify to an +/// unsatisfiable expression. +struct IsInValueSetSimplifier { + template + Status Visit(const T&) { + ARROW_ASSIGN_OR_RAISE(result, SimplifyBasic()); + return Status::OK(); + } + + template + enable_if_t || std::is_base_of_v, + Status> + Visit(const T&) { + auto simplified = + enable_fast_simplification ? SimplifyOptimized() : Status::Invalid(); + if (simplified.ok()) { + result = simplified.ValueUnsafe(); + } else { + ARROW_ASSIGN_OR_RAISE(result, SimplifyBasic()); + } + return Status::OK(); + } + + /// Simplify the value set using a linear scan filter. + Result> SimplifyBasic() { + FilterOptions::NullSelectionBehavior null_selection; + switch (null_matching) { + case SetLookupOptions::MATCH: + null_selection = + guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP; + break; + case SetLookupOptions::SKIP: + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::EMIT_NULL: + if (guarantee.nullable) return value_set; + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::INCONCLUSIVE: + if (guarantee.nullable) return value_set; + ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(value_set)); + ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null)); + if (any_null.scalar_as().value) return value_set; + null_selection = FilterOptions::DROP; + break; + } + + std::string func_name = Comparison::GetName(guarantee.cmp); + DCHECK_NE(func_name, "na"); + std::vector args{value_set, guarantee.bound}; + ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args)); + FilterOptions filter_options(null_selection); + ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set, + Filter(value_set, filter_mask, filter_options)); + return simplified_value_set.make_array(); + } + + /// Simplify the value set using binary search. + /// + /// \pre `value_set` is sorted + /// \pre `value_set` contains no duplicates + /// \pre `value_set` contains no nulls + template + Result> SimplifyOptimized() { + if (guarantee.nullable) return Status::Invalid(); + if (null_matching == SetLookupOptions::INCONCLUSIVE) return Status::Invalid(); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar_bound, + guarantee.bound.scalar()->CastTo(value_set->type())); + auto bound = internal::UnboxScalar::Unbox(*scalar_bound); + auto compare = [&](size_t i) -> Comparison::type { + DCHECK(value_set->IsValid(i)); + auto value = checked_pointer_cast(value_set)->GetView(i); + return value == bound ? Comparison::EQUAL + : value < bound ? Comparison::LESS + : Comparison::GREATER; + }; + + size_t lo = 0; + size_t hi = value_set->length(); + while (lo + 1 < hi) { + size_t mid = (lo + hi) / 2; + Comparison::type cmp = compare(mid); + if (cmp & Comparison::LESS_EQUAL) { + lo = mid; + } else { + hi = mid; + } + } + + Comparison::type cmp = compare(lo); + size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0); + bool found = cmp == Comparison::EQUAL; + + switch (guarantee.cmp) { + case Comparison::EQUAL: + return value_set->Slice(pivot, found ? 1 : 0); + case Comparison::LESS: + return value_set->Slice(0, pivot); + case Comparison::LESS_EQUAL: + return value_set->Slice(0, pivot + (found ? 1 : 0)); + case Comparison::GREATER: + return value_set->Slice(pivot + (found ? 1 : 0)); + case Comparison::GREATER_EQUAL: + return value_set->Slice(pivot); + case Comparison::NOT_EQUAL: + case Comparison::NA: + DCHECK(false); + return Status::Invalid("Invalid comparison"); + } + } + + static Result> Simplify( + std::shared_ptr value_set, const Inequality& guarantee, + SetLookupOptions::NullMatchingBehavior null_matching, + bool enable_fast_simplification) { + IsInValueSetSimplifier simplifier{value_set, guarantee, null_matching, + enable_fast_simplification, nullptr}; + RETURN_NOT_OK(VisitArrayInline(*value_set, &simplifier)); + return simplifier.result; + } + + std::shared_ptr value_set; + const Inequality& guarantee; + SetLookupOptions::NullMatchingBehavior null_matching; + bool enable_fast_simplification; + std::shared_ptr result; +}; + +/// Simplify an `is_in` call against a list of inequality guarantees. +/// +/// Simplification is done across all guarantee conjunction members at once to +/// avoid the cost of repeatedly binding the simplified expression, which is +/// linear in the size of the `is_in` value set. +/// +/// Returns a simplified expression, or nullopt if no simfpliciation occurred. +Result> SimplifyIsInWithGuarantees( + const Expression::Call* is_in_call, + const std::vector& guarantee_conjunction_members) { + DCHECK_EQ(is_in_call->function_name, "is_in"); + DCHECK_EQ(is_in_call->arguments.size(), 1); + + auto options = checked_pointer_cast(is_in_call->options); + + const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]); + if (!lhs.field_ref()) return std::nullopt; + + std::vector guarantees; + for (const Expression& guarantee : guarantee_conjunction_members) { + std::optional inequality = Inequality::ExtractOne(guarantee); + if (!inequality) continue; + if (inequality->target != *lhs.field_ref()) continue; + guarantees.emplace_back(std::move(*inequality)); + } + + bool guaranteed_non_nullable = + std::any_of(guarantees.begin(), guarantees.end(), + [](const Inequality& guarantee) { return !guarantee.nullable; }); + + std::shared_ptr simplified_value_set; + bool enable_fast_simplification = false; + if (guaranteed_non_nullable && + options->null_matching_behavior != SetLookupOptions::INCONCLUSIVE) { + auto state = + checked_pointer_cast(is_in_call->kernel_state); + simplified_value_set = state->sorted_and_unique_value_set; + enable_fast_simplification = static_cast(simplified_value_set); + } + if (!simplified_value_set) { + if (options->value_set.is_array()) { + simplified_value_set = options->value_set.make_array(); + } else if (options->value_set.is_chunked_array()) { + ARROW_ASSIGN_OR_RAISE(simplified_value_set, + Concatenate(options->value_set.chunked_array()->chunks())); + } else { + return Status::Invalid("`is_in` value set must be an array or chunked array"); + } + } + + for (Inequality& guarantee : guarantees) { + if (guaranteed_non_nullable) guarantee.nullable = false; + ARROW_ASSIGN_OR_RAISE(simplified_value_set, IsInValueSetSimplifier::Simplify( + simplified_value_set, guarantee, + options->null_matching_behavior, + enable_fast_simplification)); + } + + if (simplified_value_set->length() == 0) return literal(false); + if (simplified_value_set->length() == options->value_set.length()) return std::nullopt; + + ExecContext exec_context; + Expression::Call simplified_call; + simplified_call.function_name = "is_in"; + simplified_call.arguments = is_in_call->arguments; + simplified_call.options = std::make_shared( + simplified_value_set, options->null_matching_behavior); + ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, + BindNonRecursive(std::move(simplified_call), + /*insert_implicit_casts=*/false, &exec_context)); + return simplified_expr; +} + } // namespace Result SimplifyWithGuarantee(Expression expr, @@ -1464,6 +1602,24 @@ Result SimplifyWithGuarantee(Expression expr, } } + auto simplify_is_in = [&](Expression expr, ...) -> Result { + if (expr.call() && expr.call()->function_name == "is_in") { + ARROW_ASSIGN_OR_RAISE(auto simplified, + SimplifyIsInWithGuarantees(expr.call(), conjunction_members)); + return simplified.value_or(expr); + } else { + return expr; + } + }; + ARROW_ASSIGN_OR_RAISE( + auto simplified, + ModifyExpression( + std::move(expr), [](Expression expr) { return expr; }, simplify_is_in)); + if (!Identical(simplified, expr)) { + expr = std::move(simplified); + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + } + return expr; } diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index e2d5583e36e..b18dbde484b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -16,9 +16,12 @@ // under the License. #include "arrow/array/array_base.h" +#include "arrow/array/concatenate.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" #include "arrow/compute/kernels/common_internal.h" +#include "arrow/compute/kernels/set_lookup_internal.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/type.h" #include "arrow/util/bit_util.h" @@ -34,10 +37,26 @@ using internal::HashTraits; namespace compute::internal { namespace { -// This base class enables non-templated access to the value set type -struct SetLookupStateBase : public KernelState { - std::shared_ptr value_set_type; -}; +template +Result> SortAndUnique(std::shared_ptr value_set) { + if constexpr (std::is_same_v) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr value_set_array, + Concatenate(value_set->chunks())); + return SortAndUnique(value_set_array); + } else { + ARROW_ASSIGN_OR_RAISE(value_set, Unique(std::move(value_set))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr sort_indices, + SortIndices(value_set, SortOptions({}, NullPlacement::AtEnd))); + ARROW_ASSIGN_OR_RAISE( + value_set, Take(*value_set, *sort_indices, TakeOptions(/*bounds_check=*/false))); + if (value_set->length() > 0 && value_set->IsNull(value_set->length() - 1)) { + // If the last one is null we know it's the only one because of the call + // to `Unique` above. + value_set = value_set->Slice(0, value_set->length() - 1); + } + return value_set; + } +} template struct SetLookupState : public SetLookupStateBase { @@ -209,6 +228,21 @@ struct InitStateVisitor { } Result> GetResult() { + if (!options.value_set.is_arraylike()) { + return Status::Invalid("Set lookup value set must be Array or ChunkedArray"); + } + + // The sorted and unique value set needs to be derived from the value set + // before casting occurs. + std::shared_ptr sorted_and_unique_value_set; + if (options.value_set.is_chunked_array()) { + sorted_and_unique_value_set = + SortAndUnique(options.value_set.chunked_array()).ValueOr(nullptr); + } else { + sorted_and_unique_value_set = + SortAndUnique(options.value_set.make_array()).ValueOr(nullptr); + } + if (arg_type.id() == Type::TIMESTAMP && options.value_set.type()->id() == Type::TIMESTAMP) { // Other types will fail when casting, so no separate check is needed @@ -228,9 +262,7 @@ struct InitStateVisitor { " vs ", *options.value_set.type()); } - if (!options.value_set.is_arraylike()) { - return Status::Invalid("Set lookup value set must be Array or ChunkedArray"); - } else if (!options.value_set.type()->Equals(*arg_type)) { + if (!options.value_set.type()->Equals(*arg_type)) { auto cast_result = Cast(options.value_set, CastOptions::Safe(arg_type.GetSharedPtr()), ctx->exec_context()); @@ -252,6 +284,8 @@ struct InitStateVisitor { } RETURN_NOT_OK(VisitTypeInline(*options.value_set.type(), this)); + checked_cast(result.get())->sorted_and_unique_value_set = + sorted_and_unique_value_set; return std::move(result); } }; diff --git a/cpp/src/arrow/compute/kernels/set_lookup_internal.h b/cpp/src/arrow/compute/kernels/set_lookup_internal.h new file mode 100644 index 00000000000..e168ed09056 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/set_lookup_internal.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/kernel.h" + +namespace arrow::compute::internal { + +/// Base class for `is_in` and `index_in` kernel states. +struct SetLookupStateBase : public KernelState { + /// Enables non-templated access to the value set type. + std::shared_ptr value_set_type; + /// Enables fast simplification for `is_in` expressions. + /// + /// This field may be null. + /// + /// \invariant sorted and contains no null or duplicate values + std::shared_ptr sorted_and_unique_value_set; +}; + +} // namespace arrow::compute::internal