Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <unordered_set>

#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/expression_internal.h"
Expand Down Expand Up @@ -1242,6 +1243,72 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Simplify an `is_in` call against an inequality guarantee.
///
/// The value set is filtered with `value <cmp> guarantee.bound`. If no element
/// survives, the call is replaced by the literal `false`; if only some survive,
/// a new `is_in` call over the reduced value set is bound and returned.
///
/// We avoid the complexity of fully simplifying EQUAL comparisons to true
/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
///
/// \param guarantee the inequality assumed to hold for `guarantee.target`
/// \param is_in_call the call to simplify
/// \pre `is_in_call` is a call to the `is_in` function
/// \return a simplified expression, or nullopt if no simplification occurred
static Result<std::optional<Expression>> SimplifyIsIn(
    const Inequality& guarantee, const Expression::Call* is_in_call) {
  DCHECK_EQ(is_in_call->function_name, "is_in");

  auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

  // Only a call whose (cast-stripped) argument is the guaranteed field itself
  // can be simplified.
  const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
  if (!lhs.field_ref()) return std::nullopt;
  if (*lhs.field_ref() != guarantee.target) return std::nullopt;

  // Decide what happens to value-set entries whose filter mask slot is null
  // (presumably the null entries of the value set -- confirm against the
  // comparison kernels). Left empty when the null matching behavior is not one
  // of the enumerators handled below, in which case no simplification is
  // attempted instead of reading an indeterminate value.
  std::optional<FilterOptions::NullSelectionBehavior> null_selection;
  switch (options->null_matching_behavior) {
    case SetLookupOptions::MATCH:
      // Keep such entries only if the guarantee admits null field values.
      null_selection =
          guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
      break;
    case SetLookupOptions::SKIP:
      null_selection = FilterOptions::DROP;
      break;
    case SetLookupOptions::EMIT_NULL:
      // Not simplified when the field may be null.
      if (guarantee.nullable) return std::nullopt;
      null_selection = FilterOptions::DROP;
      break;
    case SetLookupOptions::INCONCLUSIVE: {
      // Not simplified when the field may be null or when the value set
      // contains a null.
      if (guarantee.nullable) return std::nullopt;
      ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set));
      ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
      // Check validity first so the payload of a null result is never relied on.
      const auto& any_null_scalar = any_null.scalar_as<BooleanScalar>();
      if (any_null_scalar.is_valid && any_null_scalar.value) return std::nullopt;
      null_selection = FilterOptions::DROP;
      break;
    }
  }
  if (!null_selection.has_value()) return std::nullopt;

  // Build the mask `value_set <cmp> bound` and keep the entries it selects.
  std::string func_name = Comparison::GetName(guarantee.cmp);
  DCHECK_NE(func_name, "na");
  std::vector<Datum> args{options->value_set, guarantee.bound};
  ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
  FilterOptions filter_options(*null_selection);
  ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
                        Filter(options->value_set, filter_mask, filter_options));

  // Nothing left: the call simplifies to `false`.
  if (simplified_value_set.length() == 0) return literal(false);
  // Nothing removed: report that no simplification occurred.
  if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;

  // Rebuild the call over the reduced value set, keeping the original
  // arguments and null matching behavior.
  ExecContext exec_context;
  Expression::Call simplified_call;
  simplified_call.function_name = "is_in";
  simplified_call.arguments = is_in_call->arguments;
  simplified_call.options = std::make_shared<SetLookupOptions>(
      simplified_value_set, options->null_matching_behavior);
  ARROW_ASSIGN_OR_RAISE(
      Expression simplified_expr,
      BindNonRecursive(std::move(simplified_call),
                       /*insert_implicit_casts=*/false, &exec_context));
  return simplified_expr;
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
const auto& guarantee = *this;
Expand All @@ -1258,6 +1325,12 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
SimplifyIsIn(guarantee, call));
return result.value_or(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

Expand Down
173 changes: 173 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/array/builder_primitive.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -1616,6 +1617,144 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group
}

// Checks how `is_in` calls are simplified under comparison guarantees, for
// each null matching behavior.
TEST(Expression, SimplifyIsIn) {
// Builds an `is_in` call on `field` whose value set is parsed from
// `json_array` with element type `value_set_type`.
auto is_in = [](Expression field, std::shared_ptr<DataType> value_set_type,
std::string json_array,
SetLookupOptions::NullMatchingBehavior null_matching_behavior) {
SetLookupOptions options{ArrayFromJSON(value_set_type, json_array),
null_matching_behavior};
return call("is_in", {field}, options);
};

// Value sets without nulls and guarantees without an is_null alternative:
// the same result is expected for every null matching behavior.
for (SetLookupOptions::NullMatchingBehavior null_matching : {
SetLookupOptions::MATCH,
SetLookupOptions::SKIP,
SetLookupOptions::EMIT_NULL,
SetLookupOptions::INCONCLUSIVE,
}) {
// An empty value set simplifies to false.
Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

// The guaranteed value 6 is not in the value set.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(equal(field_ref("i32"), literal(6)))
.Expect(false);

// Only the elements greater than 3 remain.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching));

// No element is greater than 9.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(9)))
.Expect(false);

// No element is less than or equal to 0.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(less_equal(field_ref("i32"), literal(0)))
.Expect(false);

// Every element already satisfies the guarantee: nothing to simplify.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(0)))
.ExpectUnchanged();

// Same with an upper bound that includes the largest element.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(less_equal(field_ref("i32"), literal(9)))
.ExpectUnchanged();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd expect this to pass:

Suggested change
Simplify{is_in(field_ref("i32"), int32(), "[1,3,null]", SetLookupOptions::MATCH)}
.WithGuarantee(
or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original intention was to not support simplification if nulls in the value set cannot be dropped (either the guarantee is nullable or the null matching behavior is INCONCLUSIVE). This is because in the optimized implementation where we binary search and slice the value set array, slicing the front would drop nulls (assuming they are placed at the end) so we would have to reallocate a new array for the simplified value set.

Do you think we ought to support nulls in the value set, and if so any thoughts on how we'd continue to support this with the binary search/slice implementation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filtering approach should be able to support arbitrary value sets, so it could serve as a fallback for the binary search/slice implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that sounds reasonable to me. I added new tests for nullable guarantees and nulls in the value set for all of the different null matching behaviors.

// Both sides of a conjunction guarantee narrow the value set.
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)),
greater(field_ref("i32"), literal(4))))
.Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching));

// int8 value set on the field `u32`.
Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching));

// int64 value set on the field `u32`.
Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching));
}

// MATCH, guarantee with an is_null alternative, no null in the value set.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::MATCH),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH));

// MATCH, guarantee without is_null: the null in the value set is dropped.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH));

// MATCH, guarantee with is_null: the null in the value set is kept.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH));

// SKIP, guarantee with is_null, no null in the value set.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::SKIP),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

// SKIP, guarantee without is_null: the null in the value set is dropped.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

// SKIP, guarantee with is_null: the null in the value set is still dropped.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

// EMIT_NULL, guarantee with is_null: no simplification.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

// EMIT_NULL, guarantee without is_null: the null in the value set is dropped.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::EMIT_NULL));

// EMIT_NULL, guarantee with is_null and a null in the value set: no
// simplification.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

// INCONCLUSIVE, guarantee with is_null: no simplification.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

// INCONCLUSIVE, null in the value set: no simplification.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.ExpectUnchanged();

// INCONCLUSIVE, guarantee with is_null and a null in the value set: no
// simplification.
Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();
}

TEST(Expression, SimplifyThenExecute) {
auto filter =
or_({equal(field_ref("f32"), literal(0)),
Expand Down Expand Up @@ -1643,6 +1782,40 @@ TEST(Expression, SimplifyThenExecute) {
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}

// Checks that an `is_in` filter and its simplification under a guarantee
// evaluate to the same values on a batch whose rows satisfy that guarantee.
TEST(Expression, SimplifyIsInThenExecute) {
  // Every row satisfies all of the guarantees listed below.
  auto input = RecordBatchFromJSON(kBoringSchema, R"([
{"i64": 2, "i32": 5},
{"i64": 5, "i32": 6},
{"i64": 3, "i32": 6},
{"i64": 3, "i32": 5},
{"i64": 4, "i32": 5},
{"i64": 2, "i32": 7},
{"i64": 5, "i32": 5}
])");

  std::vector<Expression> guarantees{greater(field_ref("i64"), literal(1)),
                                     greater_equal(field_ref("i32"), literal(5)),
                                     less_equal(field_ref("i64"), literal(5))};

  for (const Expression& guarantee : guarantees) {
    // Filter on the same field the guarantee constrains. The null matching
    // behavior is spelled with the NullMatchingBehavior enum, as in the other
    // is_in tests, rather than the legacy `bool skip_nulls` argument.
    auto filter = call("is_in", {guarantee.call()->arguments[0]},
                       compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"),
                                                 compute::SetLookupOptions::SKIP});
    ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
    ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee));

    Datum evaluated, simplified_evaluated;
    ExpectExecute(filter, input, &evaluated);
    ExpectExecute(simplified, input, &simplified_evaluated);
    // A filter simplified to a literal evaluates to a scalar; broadcast it to
    // the batch length so both results can be compared element-wise.
    if (simplified_evaluated.is_scalar()) {
      ASSERT_OK_AND_ASSIGN(
          simplified_evaluated,
          MakeArrayFromScalar(*simplified_evaluated.scalar(), evaluated.length()));
    }
    AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
  }
}

TEST(Expression, Filter) {
auto ExpectFilter = [](Expression filter, std::string batch_json) {
ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean())));
Expand Down