From 1bd40695c71baeeeb325130231d128229d99839c Mon Sep 17 00:00:00 2001 From: Larry Wang Date: Mon, 19 Aug 2024 16:21:53 -0400 Subject: [PATCH 1/4] implement basic is_in simplification --- cpp/src/arrow/compute/expression.cc | 52 ++++++++++++++++ cpp/src/arrow/compute/expression_test.cc | 76 ++++++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index 33e5928c286..6e5219a57b1 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -1242,6 +1242,52 @@ struct Inequality { /*insert_implicit_casts=*/false, &exec_context); } + /// Simplify an `is_in` call against an inequality guarantee. + /// \pre `is_in_call` is a call to the `is_in` function + /// \return a simplified expression, or nullopt if no simplification occurred + static Result> SimplifyIsIn( + const Inequality& guarantee, const Expression::Call* is_in_call) { + DCHECK_EQ(is_in_call->function_name, "is_in"); + + auto options = checked_pointer_cast(is_in_call->options); + + // Null-matching behavior is complex and reduces the chances of reduction + // of `is_in` calls to a single literal for every possible input, so we + // abort the simplification if nulls are possible in the input or output. + if (guarantee.nullable || + options->null_matching_behavior == SetLookupOptions::INCONCLUSIVE) { + return std::nullopt; + } + + const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]); + if (!lhs.field_ref()) return std::nullopt; + if (*lhs.field_ref() != guarantee.target) return std::nullopt; + + std::string func_name = Comparison::GetName(guarantee.cmp); + DCHECK_NE(func_name, "na"); + std::vector args{options->value_set, guarantee.bound}; + ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args)); + FilterOptions filter_options(FilterOptions::DROP); + ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set, + Filter(options->value_set, filter_mask, filter_options)); + + if (simplified_value_set.length() == 0) return literal(false); + if (guarantee.cmp == Comparison::EQUAL) return literal(true); + if (simplified_value_set.length() == options->value_set.length()) return std::nullopt; + + ExecContext exec_context; + Expression::Call simplified_call; + simplified_call.function_name = "is_in"; + simplified_call.arguments = is_in_call->arguments; + simplified_call.options = std::make_shared( + simplified_value_set, options->null_matching_behavior); + ARROW_ASSIGN_OR_RAISE( + Expression simplified_expr, + BindNonRecursive(std::move(simplified_call), + /*insert_implicit_casts=*/false, &exec_context)); + return simplified_expr; + } + /// \brief Simplify the given expression given this inequality as a guarantee. Result Simplify(Expression expr) { const auto& guarantee = *this; @@ -1258,6 +1304,12 @@ struct Inequality { return call->function_name == "is_valid" ? literal(true) : literal(false); } + if (call->function_name == "is_in") { + ARROW_ASSIGN_OR_RAISE(std::optional result, + SimplifyIsIn(guarantee, call)); + return result.value_or(expr); + } + auto cmp = Comparison::Get(expr); if (!cmp) return expr; diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index d94a17b6ffa..6e866252fc4 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -1616,6 +1616,82 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) { true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group } +TEST(Expression, SimplifyIsIn) { + auto is_in = [](Expression field, std::shared_ptr value_set_type, + std::string json_array, + SetLookupOptions::NullMatchingBehavior null_matching_behavior) { + SetLookupOptions options{ArrayFromJSON(value_set_type, json_array), + null_matching_behavior}; + return call("is_in", {field}, options); + }; + + for (SetLookupOptions::NullMatchingBehavior null_matching_behavior : + {SetLookupOptions::MATCH, SetLookupOptions::SKIP, SetLookupOptions::EMIT_NULL}) { + Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[null]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(equal(field_ref("i32"), literal(7))) + .Expect(true); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(equal(field_ref("i32"), literal(6))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("i32"), literal(3))) + .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]", null_matching_behavior) + } + .WithGuarantee(greater(field_ref("i32"), literal(3))) + .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior)); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("i32"), literal(9))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(less_equal(field_ref("i32"), literal(0))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("i32"), literal(0))) + .ExpectUnchanged(); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee( + or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee( + and_(less_equal(field_ref("i32"), literal(7)), + greater(field_ref("i32"), literal(4)))) + .Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching_behavior)); + + Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("u32"), literal(3))) + .Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching_behavior)); + + Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching_behavior)} + .WithGuarantee(greater(field_ref("u32"), literal(3))) + .Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching_behavior)); + } + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", SetLookupOptions::INCONCLUSIVE) + } + .WithGuarantee(greater(field_ref("i32"), literal(3))) + .ExpectUnchanged(); +} + TEST(Expression, SimplifyThenExecute) { auto filter = or_({equal(field_ref("f32"), literal(0)), From fda5a537b1aad0a077bc217da484aa5b8fd26747 Mon Sep 17 00:00:00 2001 From: Larry Wang Date: Fri, 23 Aug 2024 16:01:43 -0400 Subject: [PATCH 2/4] add execute tests --- cpp/src/arrow/compute/expression_test.cc | 48 +++++++++++++++++++++--- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index 6e866252fc4..b5265bf431a 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -27,6 +27,7 @@ #include #include +#include "arrow/array/builder_primitive.h" #include "arrow/compute/expression_internal.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" @@ -1648,7 +1649,7 @@ TEST(Expression, SimplifyIsIn) { .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior)); Simplify{ - is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]", null_matching_behavior) + is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]", null_matching_behavior), } .WithGuarantee(greater(field_ref("i32"), literal(3))) .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior)); @@ -1671,9 +1672,8 @@ TEST(Expression, SimplifyIsIn) { .ExpectUnchanged(); Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} - .WithGuarantee( - and_(less_equal(field_ref("i32"), literal(7)), - greater(field_ref("i32"), literal(4)))) + .WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)), + greater(field_ref("i32"), literal(4)))) .Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching_behavior)); Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching_behavior)} @@ -1686,7 +1686,7 @@ TEST(Expression, SimplifyIsIn) { } Simplify{ - is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", SetLookupOptions::INCONCLUSIVE) + is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", SetLookupOptions::INCONCLUSIVE), } .WithGuarantee(greater(field_ref("i32"), literal(3))) .ExpectUnchanged(); @@ -1719,6 +1719,44 @@ TEST(Expression, SimplifyThenExecute) { AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } +TEST(Expression, SimplifyIsInThenExecute) { + auto input = RecordBatchFromJSON(kBoringSchema, R"([ + {"i64": 2, "i32": 5}, + {"i64": 5, "i32": 6}, + {"i64": 3, "i32": 6}, + {"i64": 3, "i32": 5}, + {"i64": 4, "i32": 5}, + {"i64": 2, "i32": 7}, + {"i64": 5, "i32": 5} + ])"); + + std::vector guarantees{ + greater(field_ref("i64"), literal(1)), + greater_equal(field_ref("i32"), literal(5)), + less_equal(field_ref("i64"), literal(5))}; + + for (const Expression& guarantee : guarantees) { + auto filter = call( + "is_in", {guarantee.call()->arguments[0]}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true}); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee)); + + Datum evaluated, simplified_evaluated; + ExpectExecute(filter, input, &evaluated); + ExpectExecute(simplified, input, &simplified_evaluated); + if (simplified_evaluated.is_scalar()) { + ASSERT_EQ(evaluated.kind(), Datum::ARRAY); + ASSERT_EQ(simplified_evaluated.type()->id(), Type::BOOL); + BooleanBuilder builder; + ASSERT_OK(builder.AppendValues( + evaluated.length(), simplified_evaluated.scalar_as().value)); + ASSERT_OK_AND_ASSIGN(simplified_evaluated, builder.Finish()); + } + AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); + } +} + TEST(Expression, Filter) { auto ExpectFilter = [](Expression filter, std::string batch_json) { ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean()))); From 4a3e6bb9f84f2df275fdf08e85b6a8a193e27a6d Mon Sep 17 00:00:00 2001 From: larry98 Date: Wed, 28 Aug 2024 12:07:57 -0400 Subject: [PATCH 3/4] Update cpp/src/arrow/compute/expression_test.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/compute/expression_test.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index b5265bf431a..55ffcc721ee 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -1746,12 +1746,8 @@ TEST(Expression, SimplifyIsInThenExecute) { ExpectExecute(filter, input, &evaluated); ExpectExecute(simplified, input, &simplified_evaluated); if (simplified_evaluated.is_scalar()) { - ASSERT_EQ(evaluated.kind(), Datum::ARRAY); - ASSERT_EQ(simplified_evaluated.type()->id(), Type::BOOL); - BooleanBuilder builder; - ASSERT_OK(builder.AppendValues( - evaluated.length(), simplified_evaluated.scalar_as().value)); - ASSERT_OK_AND_ASSIGN(simplified_evaluated, builder.Finish()); + ASSERT_OK_AND_ASSIGN(simplified_evaluated, + MakeArrayFromScalar(*simplified_evaluated.scalar(), expected->length())); } AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } From 75fd2ea0279bcc4385a3d0f12b8ba69709235fb2 Mon Sep 17 00:00:00 2001 From: Larry Wang Date: Wed, 28 Aug 2024 17:33:43 -0400 Subject: [PATCH 4/4] support all null matching behaviors --- cpp/src/arrow/compute/expression.cc | 41 +++++-- cpp/src/arrow/compute/expression_test.cc | 149 ++++++++++++++++------- 2 files changed, 137 insertions(+), 53 deletions(-) diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index 6e5219a57b1..12fda5d58f3 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -23,6 +23,7 @@ #include #include "arrow/chunked_array.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/expression_internal.h" @@ -1243,6 +1244,13 @@ struct Inequality { } /// Simplify an `is_in` call against an inequality guarantee. + /// + /// We avoid the complexity of fully simplifying EQUAL comparisons to true + /// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to + /// potential complications with null matching behavior. This is ok for the + /// predicate pushdown use case because the overall aim is to simplify to an + /// unsatisfiable expression. + /// /// \pre `is_in_call` is a call to the `is_in` function /// \return a simplified expression, or nullopt if no simplification occurred static Result> SimplifyIsIn( @@ -1251,28 +1259,41 @@ struct Inequality { auto options = checked_pointer_cast(is_in_call->options); - // Null-matching behavior is complex and reduces the chances of reduction - // of `is_in` calls to a single literal for every possible input, so we - // abort the simplification if nulls are possible in the input or output. - if (guarantee.nullable || - options->null_matching_behavior == SetLookupOptions::INCONCLUSIVE) { - return std::nullopt; - } - const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]); if (!lhs.field_ref()) return std::nullopt; if (*lhs.field_ref() != guarantee.target) return std::nullopt; + FilterOptions::NullSelectionBehavior null_selection; + switch (options->null_matching_behavior) { + case SetLookupOptions::MATCH: + null_selection = + guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP; + break; + case SetLookupOptions::SKIP: + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::EMIT_NULL: + if (guarantee.nullable) return std::nullopt; + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::INCONCLUSIVE: + if (guarantee.nullable) return std::nullopt; + ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set)); + ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null)); + if (any_null.scalar_as().value) return std::nullopt; + null_selection = FilterOptions::DROP; + break; + } + std::string func_name = Comparison::GetName(guarantee.cmp); DCHECK_NE(func_name, "na"); std::vector args{options->value_set, guarantee.bound}; ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args)); - FilterOptions filter_options(FilterOptions::DROP); + FilterOptions filter_options(null_selection); ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set, Filter(options->value_set, filter_mask, filter_options)); if (simplified_value_set.length() == 0) return literal(false); - if (guarantee.cmp == Comparison::EQUAL) return literal(true); if (simplified_value_set.length() == options->value_set.length()) return std::nullopt; ExecContext exec_context; diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index 55ffcc721ee..0b7e8a9c23b 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -1626,69 +1626,132 @@ TEST(Expression, SimplifyIsIn) { return call("is_in", {field}, options); }; - for (SetLookupOptions::NullMatchingBehavior null_matching_behavior : - {SetLookupOptions::MATCH, SetLookupOptions::SKIP, SetLookupOptions::EMIT_NULL}) { - Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching_behavior)} - .WithGuarantee(greater(field_ref("i32"), literal(2))) - .Expect(false); - - Simplify{is_in(field_ref("i32"), int32(), "[null]", null_matching_behavior)} + for (SetLookupOptions::NullMatchingBehavior null_matching : { + SetLookupOptions::MATCH, + SetLookupOptions::SKIP, + SetLookupOptions::EMIT_NULL, + SetLookupOptions::INCONCLUSIVE, + }) { + Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching)} .WithGuarantee(greater(field_ref("i32"), literal(2))) .Expect(false); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} - .WithGuarantee(equal(field_ref("i32"), literal(7))) - .Expect(true); - - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(equal(field_ref("i32"), literal(6))) .Expect(false); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} - .WithGuarantee(greater(field_ref("i32"), literal(3))) - .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior)); - - Simplify{ - is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]", null_matching_behavior), - } + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(greater(field_ref("i32"), literal(3))) - .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior)); + .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching)); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(greater(field_ref("i32"), literal(9))) .Expect(false); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(less_equal(field_ref("i32"), literal(0))) .Expect(false); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(greater(field_ref("i32"), literal(0))) .ExpectUnchanged(); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} - .WithGuarantee( - or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32")))) + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(less_equal(field_ref("i32"), literal(9))) .ExpectUnchanged(); - Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)), greater(field_ref("i32"), literal(4)))) - .Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching_behavior)); + .Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching)); - Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(greater(field_ref("u32"), literal(3))) - .Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching_behavior)); + .Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching)); - Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching_behavior)} + Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching)} .WithGuarantee(greater(field_ref("u32"), literal(3))) - .Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching_behavior)); + .Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching)); } Simplify{ - is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", SetLookupOptions::INCONCLUSIVE), + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::MATCH), } - .WithGuarantee(greater(field_ref("i32"), literal(3))) + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::SKIP), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::EMIT_NULL), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::EMIT_NULL)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::INCONCLUSIVE), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) .ExpectUnchanged(); } @@ -1730,15 +1793,14 @@ TEST(Expression, SimplifyIsInThenExecute) { {"i64": 5, "i32": 5} ])"); - std::vector guarantees{ - greater(field_ref("i64"), literal(1)), - greater_equal(field_ref("i32"), literal(5)), - less_equal(field_ref("i64"), literal(5))}; + std::vector guarantees{greater(field_ref("i64"), literal(1)), + greater_equal(field_ref("i32"), literal(5)), + less_equal(field_ref("i64"), literal(5))}; for (const Expression& guarantee : guarantees) { - auto filter = call( - "is_in", {guarantee.call()->arguments[0]}, - compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true}); + auto filter = + call("is_in", {guarantee.call()->arguments[0]}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true}); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee)); @@ -1746,8 +1808,9 @@ TEST(Expression, SimplifyIsInThenExecute) { ExpectExecute(filter, input, &evaluated); ExpectExecute(simplified, input, &simplified_evaluated); if (simplified_evaluated.is_scalar()) { - ASSERT_OK_AND_ASSIGN(simplified_evaluated, - MakeArrayFromScalar(*simplified_evaluated.scalar(), expected->length())); + ASSERT_OK_AND_ASSIGN( + simplified_evaluated, + MakeArrayFromScalar(*simplified_evaluated.scalar(), evaluated.length())); } AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); }