Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 50 additions & 72 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
_ => {
let left_expr = self.sql_to_expr(*left, schema, planner_context)?;
let right_expr = self.sql_to_expr(*right, schema, planner_context)?;
plan_any_op(left_expr, right_expr, &compare_op)
plan_quantified_op(
&left_expr,
&right_expr,
&compare_op,
SetQuantifier::Any,
)
}
},
SQLExpr::AllOp {
Expand All @@ -640,7 +645,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
_ => {
let left_expr = self.sql_to_expr(*left, schema, planner_context)?;
let right_expr = self.sql_to_expr(*right, schema, planner_context)?;
plan_all_op(&left_expr, &right_expr, &compare_op)
plan_quantified_op(
&left_expr,
&right_expr,
&compare_op,
SetQuantifier::All,
)
}
},
#[expect(deprecated)]
Expand Down Expand Up @@ -1249,73 +1259,20 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}

/// Builds a CASE expression that handles NULL semantics for `x <op> ANY(arr)`:
///
/// ```text
/// CASE
/// WHEN <min_or_max>(arr) IS NOT NULL THEN <comparison>
/// WHEN arr IS NOT NULL THEN FALSE -- empty or all-null array
/// ELSE NULL -- NULL array
/// END
/// ```
fn any_op_with_null_handling(bound: Expr, comparison: Expr, arr: Expr) -> Result<Expr> {
when(bound.is_not_null(), comparison)
.when(arr.is_not_null(), lit(false))
.otherwise(lit(ScalarValue::Boolean(None)))
}

/// Plans a `<left> <op> ANY(<right>)` expression for non-subquery operands.
fn plan_any_op(
left_expr: Expr,
right_expr: Expr,
compare_op: &BinaryOperator,
) -> Result<Expr> {
match compare_op {
BinaryOperator::Eq => Ok(array_has(right_expr, left_expr)),
BinaryOperator::NotEq => {
let min = array_min(right_expr.clone());
let max = array_max(right_expr.clone());
// NOT EQ is true when either bound differs from left
let comparison = min
.not_eq(left_expr.clone())
.or(max.clone().not_eq(left_expr));
any_op_with_null_handling(max, comparison, right_expr)
}
BinaryOperator::Gt => {
let min = array_min(right_expr.clone());
any_op_with_null_handling(min.clone(), min.lt(left_expr), right_expr)
}
BinaryOperator::Lt => {
let max = array_max(right_expr.clone());
any_op_with_null_handling(max.clone(), max.gt(left_expr), right_expr)
}
BinaryOperator::GtEq => {
let min = array_min(right_expr.clone());
any_op_with_null_handling(min.clone(), min.lt_eq(left_expr), right_expr)
}
BinaryOperator::LtEq => {
let max = array_max(right_expr.clone());
any_op_with_null_handling(max.clone(), max.gt_eq(left_expr), right_expr)
}
_ => plan_err!(
"Unsupported AnyOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
),
}
}

/// Plans `needle <compare_op> ALL(haystack)` with proper SQL NULL semantics.
/// Plans `needle <compare_op> ANY/ALL(haystack)` with proper SQL NULL semantics.
///
/// CASE/WHEN structure:
/// WHEN arr IS NULL → NULL
/// WHEN empty → TRUE
/// WHEN empty → vacuous_result (ANY:false, ALL:true)
/// WHEN lhs IS NULL → NULL
/// WHEN decisive_condition → FALSE
/// WHEN decisive_condition → decisive_result (ANY:true match found, ALL:false violation found)
/// WHEN has_nulls → NULL
/// ELSE → TRUE
fn plan_all_op(
/// ELSE → vacuous_result
fn plan_quantified_op(
needle: &Expr,
haystack: &Expr,
compare_op: &BinaryOperator,
quantifier: SetQuantifier,
) -> Result<Expr> {
let null_arr_check = haystack.clone().is_null();
let empty_check = cardinality(haystack.clone()).eq(lit(0u64));
Expand All @@ -1325,40 +1282,61 @@ fn plan_all_op(
let has_nulls =
array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null();

let decisive_condition = match compare_op {
BinaryOperator::NotEq => array_has(haystack.clone(), needle.clone()),
BinaryOperator::Eq => {
let decisive_condition = match (compare_op, quantifier) {
(BinaryOperator::Eq, SetQuantifier::Any)
| (BinaryOperator::NotEq, SetQuantifier::All) => {
array_has(haystack.clone(), needle.clone())
}
(BinaryOperator::Eq, SetQuantifier::All)
| (BinaryOperator::NotEq, SetQuantifier::Any) => {
let all_equal = array_min(haystack.clone())
.eq(needle.clone())
.and(array_max(haystack.clone()).eq(needle.clone()));
Expr::Not(Box::new(all_equal))
}
BinaryOperator::Gt => {
(BinaryOperator::Gt, SetQuantifier::Any) => {
needle.clone().gt(array_min(haystack.clone()))
}
(BinaryOperator::Gt, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone()))))
}
BinaryOperator::Lt => {
(BinaryOperator::Lt, SetQuantifier::Any) => {
needle.clone().lt(array_max(haystack.clone()))
}
(BinaryOperator::Lt, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone()))))
}
BinaryOperator::GtEq => {
(BinaryOperator::GtEq, SetQuantifier::Any) => {
needle.clone().gt_eq(array_min(haystack.clone()))
}
(BinaryOperator::GtEq, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone()))))
}
BinaryOperator::LtEq => {
(BinaryOperator::LtEq, SetQuantifier::Any) => {
needle.clone().lt_eq(array_max(haystack.clone()))
}
(BinaryOperator::LtEq, SetQuantifier::All) => {
Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone()))))
}
_ => {
return plan_err!(
"Unsupported AllOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
"Unsupported {quantifier}Op: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported"
);
}
};

let (vacuous_result, decisive_result) = match quantifier {
SetQuantifier::Any => (false, true),
SetQuantifier::All => (true, false),
};

let null_bool = lit(ScalarValue::Boolean(None));
when(null_arr_check, null_bool.clone())
.when(empty_check, lit(true))
.when(empty_check, lit(vacuous_result))
.when(null_lhs_check, null_bool.clone())
.when(decisive_condition, lit(false))
.when(decisive_condition, lit(decisive_result))
.when(has_nulls, null_bool)
.otherwise(lit(true))
.otherwise(lit(vacuous_result))
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError>
sql: "select left from array where 1 = any(left);",
parser_dialect: GenericDialect {},
unparser_dialect: UnparserPostgreSqlDialect {},
expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = ANY("array"."left")"#,
expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#,
);
Ok(())
}
Expand Down
70 changes: 53 additions & 17 deletions datafusion/sqllogictest/test_files/array/array_has.slt
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,18 @@ logical_plan
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: generate_series() projection=[value]
06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL
07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3
08)--------------TableScan: generate_series() projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
03)----CoalescePartitionsExec
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[]
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[]
06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3]
07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

query I
with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
Expand Down Expand Up @@ -754,26 +756,26 @@ select 5 <= any(make_array());
false

# Mixed NULL + non-NULL array where no non-NULL element satisfies the condition
# These return false (NULLs are skipped by array_min/array_max)
# These return NULL because NULLs leave the result indeterminate
query B
select 5 > any(make_array(6, NULL));
----
false
NULL

query B
select 5 < any(make_array(3, NULL));
----
false
NULL

query B
select 5 >= any(make_array(6, NULL));
----
false
NULL

query B
select 5 <= any(make_array(3, NULL));
----
false
NULL

# Mixed NULL + non-NULL array where a non-NULL element satisfies the condition
query B
Expand Down Expand Up @@ -804,33 +806,38 @@ true
query B
select 5 <> any(make_array(5, NULL));
----
false
NULL

# All-NULL array: all operators should return false
# All-NULL array: all operators should return NULL (unknown comparison)
query B
select 5 > any(make_array(NULL::INT, NULL::INT));
----
false
NULL

query B
select 5 < any(make_array(NULL::INT, NULL::INT));
----
false
NULL

query B
select 5 >= any(make_array(NULL::INT, NULL::INT));
----
false
NULL

query B
select 5 <= any(make_array(NULL::INT, NULL::INT));
----
false
NULL

query B
select 5 <> any(make_array(NULL::INT, NULL::INT));
----
false
NULL

query B
select 5 = any(make_array(NULL::INT, NULL::INT));
----
NULL

# NULL left operand: should return NULL for non-empty arrays
query B
Expand Down Expand Up @@ -890,6 +897,35 @@ select 5 <> any(NULL::INT[]);
----
NULL

query B
select 5 = any(NULL::INT[]);
----
NULL

# NULL = ANY with non-empty array
query B
select NULL = any(make_array(1, 2, 3));
----
NULL

# = ANY with no match, no NULLs
query B
select 5 = any(make_array(1, 2, 3));
----
false

# = ANY with mixed NULL (satisfying) returns TRUE
query B
select 5 = any(make_array(5, NULL));
----
true

# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate
query B
select 5 = any(make_array(1, 2, NULL));
----
NULL

statement ok
DROP TABLE any_op_test;

Expand Down
Loading