diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a753c91162bea..ae52d8403eef3 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -621,7 +621,12 @@ impl SqlToRel<'_, S> { _ => { let left_expr = self.sql_to_expr(*left, schema, planner_context)?; let right_expr = self.sql_to_expr(*right, schema, planner_context)?; - plan_any_op(left_expr, right_expr, &compare_op) + plan_quantified_op( + &left_expr, + &right_expr, + &compare_op, + SetQuantifier::Any, + ) } }, SQLExpr::AllOp { @@ -640,7 +645,12 @@ impl SqlToRel<'_, S> { _ => { let left_expr = self.sql_to_expr(*left, schema, planner_context)?; let right_expr = self.sql_to_expr(*right, schema, planner_context)?; - plan_all_op(&left_expr, &right_expr, &compare_op) + plan_quantified_op( + &left_expr, + &right_expr, + &compare_op, + SetQuantifier::All, + ) } }, #[expect(deprecated)] @@ -1249,73 +1259,20 @@ impl SqlToRel<'_, S> { } } -/// Builds a CASE expression that handles NULL semantics for `x ANY(arr)`: -/// -/// ```text -/// CASE -/// WHEN (arr) IS NOT NULL THEN -/// WHEN arr IS NOT NULL THEN FALSE -- empty or all-null array -/// ELSE NULL -- NULL array -/// END -/// ``` -fn any_op_with_null_handling(bound: Expr, comparison: Expr, arr: Expr) -> Result { - when(bound.is_not_null(), comparison) - .when(arr.is_not_null(), lit(false)) - .otherwise(lit(ScalarValue::Boolean(None))) -} - -/// Plans a ` ANY()` expression for non-subquery operands. -fn plan_any_op( - left_expr: Expr, - right_expr: Expr, - compare_op: &BinaryOperator, -) -> Result { - match compare_op { - BinaryOperator::Eq => Ok(array_has(right_expr, left_expr)), - BinaryOperator::NotEq => { - let min = array_min(right_expr.clone()); - let max = array_max(right_expr.clone()); - // NOT EQ is true when either bound differs from left - let comparison = min - .not_eq(left_expr.clone()) - .or(max.clone().not_eq(left_expr)); - any_op_with_null_handling(max, comparison, right_expr) - } - BinaryOperator::Gt => { - let min = array_min(right_expr.clone()); - any_op_with_null_handling(min.clone(), min.lt(left_expr), right_expr) - } - BinaryOperator::Lt => { - let max = array_max(right_expr.clone()); - any_op_with_null_handling(max.clone(), max.gt(left_expr), right_expr) - } - BinaryOperator::GtEq => { - let min = array_min(right_expr.clone()); - any_op_with_null_handling(min.clone(), min.lt_eq(left_expr), right_expr) - } - BinaryOperator::LtEq => { - let max = array_max(right_expr.clone()); - any_op_with_null_handling(max.clone(), max.gt_eq(left_expr), right_expr) - } - _ => plan_err!( - "Unsupported AnyOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" - ), - } -} - -/// Plans `needle ALL(haystack)` with proper SQL NULL semantics. +/// Plans `needle ANY/ALL(haystack)` with proper SQL NULL semantics. /// /// CASE/WHEN structure: /// WHEN arr IS NULL → NULL -/// WHEN empty → TRUE +/// WHEN empty → vacuous_result (ANY:false, ALL:true) /// WHEN lhs IS NULL → NULL -/// WHEN decisive_condition → FALSE +/// WHEN decisive_condition → decisive_result (ANY:true match found, ALL:false violation found) /// WHEN has_nulls → NULL -/// ELSE → TRUE -fn plan_all_op( +/// ELSE → vacuous_result +fn plan_quantified_op( needle: &Expr, haystack: &Expr, compare_op: &BinaryOperator, + quantifier: SetQuantifier, ) -> Result { let null_arr_check = haystack.clone().is_null(); let empty_check = cardinality(haystack.clone()).eq(lit(0u64)); @@ -1325,40 +1282,61 @@ fn plan_all_op( let has_nulls = array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null(); - let decisive_condition = match compare_op { - BinaryOperator::NotEq => array_has(haystack.clone(), needle.clone()), - BinaryOperator::Eq => { + let decisive_condition = match (compare_op, quantifier) { + (BinaryOperator::Eq, SetQuantifier::Any) + | (BinaryOperator::NotEq, SetQuantifier::All) => { + array_has(haystack.clone(), needle.clone()) + } + (BinaryOperator::Eq, SetQuantifier::All) + | (BinaryOperator::NotEq, SetQuantifier::Any) => { let all_equal = array_min(haystack.clone()) .eq(needle.clone()) .and(array_max(haystack.clone()).eq(needle.clone())); Expr::Not(Box::new(all_equal)) } - BinaryOperator::Gt => { + (BinaryOperator::Gt, SetQuantifier::Any) => { + needle.clone().gt(array_min(haystack.clone())) + } + (BinaryOperator::Gt, SetQuantifier::All) => { Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone())))) } - BinaryOperator::Lt => { + (BinaryOperator::Lt, SetQuantifier::Any) => { + needle.clone().lt(array_max(haystack.clone())) + } + (BinaryOperator::Lt, SetQuantifier::All) => { Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone())))) } - BinaryOperator::GtEq => { + (BinaryOperator::GtEq, SetQuantifier::Any) => { + needle.clone().gt_eq(array_min(haystack.clone())) + } + (BinaryOperator::GtEq, SetQuantifier::All) => { Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone())))) } - BinaryOperator::LtEq => { + (BinaryOperator::LtEq, SetQuantifier::Any) => { + needle.clone().lt_eq(array_max(haystack.clone())) + } + (BinaryOperator::LtEq, SetQuantifier::All) => { Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone())))) } _ => { return plan_err!( - "Unsupported AllOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" + "Unsupported {quantifier}Op: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" ); } }; + let (vacuous_result, decisive_result) = match quantifier { + SetQuantifier::Any => (false, true), + SetQuantifier::All => (true, false), + }; + let null_bool = lit(ScalarValue::Boolean(None)); when(null_arr_check, null_bool.clone()) - .when(empty_check, lit(true)) + .when(empty_check, lit(vacuous_result)) .when(null_lhs_check, null_bool.clone()) - .when(decisive_condition, lit(false)) + .when(decisive_condition, lit(decisive_result)) .when(has_nulls, null_bool) - .otherwise(lit(true)) + .otherwise(lit(vacuous_result)) } #[cfg(test)] diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 06680e60714b8..ba4eee7133ef2 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -367,7 +367,7 @@ fn roundtrip_statement_postgres_any_array_expr() -> Result<(), DataFusionError> sql: "select left from array where 1 = any(left);", parser_dialect: GenericDialect {}, unparser_dialect: UnparserPostgreSqlDialect {}, - expected: @r#"SELECT "array"."left" FROM "array" WHERE 1 = ANY("array"."left")"#, + expected: @r#"SELECT "array"."left" FROM "array" WHERE CASE WHEN "array"."left" IS NULL THEN NULL WHEN (cardinality("array"."left") = 0) THEN false WHEN 1 IS NULL THEN NULL WHEN 1 = ANY("array"."left") THEN true WHEN array_position("array"."left", NULL, 1) IS NOT NULL THEN NULL ELSE false END"#, ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/array/array_has.slt b/datafusion/sqllogictest/test_files/array/array_has.slt index e343c1b1fae41..abfd697a42d54 100644 --- a/datafusion/sqllogictest/test_files/array/array_has.slt +++ b/datafusion/sqllogictest/test_files/array/array_has.slt @@ -517,16 +517,18 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: generate_series() projection=[value] +06)----------Filter: __common_expr_3 IS NULL AND Boolean(NULL) OR __common_expr_3 IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) IS NOT DISTINCT FROM Boolean(true) AND __common_expr_3 IS NOT NULL +07)------------Projection: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) AS __common_expr_3 +08)--------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] -05)--------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), projection=[] -06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +05)--------FilterExec: __common_expr_3@0 IS NULL AND NULL OR __common_expr_3@0 IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) IS NOT DISTINCT FROM true AND __common_expr_3@0 IS NOT NULL, projection=[] +06)----------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8View)), 1, 32) as __common_expr_3] +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -754,26 +756,26 @@ select 5 <= any(make_array()); false # Mixed NULL + non-NULL array where no non-NULL element satisfies the condition -# These return false (NULLs are skipped by array_min/array_max) +# These return NULL because NULLs leave the result indeterminate query B select 5 > any(make_array(6, NULL)); ---- -false +NULL query B select 5 < any(make_array(3, NULL)); ---- -false +NULL query B select 5 >= any(make_array(6, NULL)); ---- -false +NULL query B select 5 <= any(make_array(3, NULL)); ---- -false +NULL # Mixed NULL + non-NULL array where a non-NULL element satisfies the condition query B @@ -804,33 +806,38 @@ true query B select 5 <> any(make_array(5, NULL)); ---- -false +NULL -# All-NULL array: all operators should return false +# All-NULL array: all operators should return NULL (unknown comparison) query B select 5 > any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 < any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 >= any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 <= any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL query B select 5 <> any(make_array(NULL::INT, NULL::INT)); ---- -false +NULL + +query B +select 5 = any(make_array(NULL::INT, NULL::INT)); +---- +NULL # NULL left operand: should return NULL for non-empty arrays query B @@ -890,6 +897,35 @@ select 5 <> any(NULL::INT[]); ---- NULL +query B +select 5 = any(NULL::INT[]); +---- +NULL + +# NULL = ANY with non-empty array +query B +select NULL = any(make_array(1, 2, 3)); +---- +NULL + +# = ANY with no match, no NULLs +query B +select 5 = any(make_array(1, 2, 3)); +---- +false + +# = ANY with mixed NULL (satisfying) returns TRUE +query B +select 5 = any(make_array(5, NULL)); +---- +true + +# = ANY with mixed NULL (non-satisfying): NULLs leave result indeterminate +query B +select 5 = any(make_array(1, 2, NULL)); +---- +NULL + statement ok DROP TABLE any_op_test;