From 8b01c38a7cc3b3f9a210cbe166442dbb8055fe65 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 24 Jun 2024 09:07:10 +0800 Subject: [PATCH 1/5] support basic list cmp Signed-off-by: jayzhan211 --- datafusion/physical-expr-common/src/datum.rs | 127 ++++++++++++++++++ datafusion/physical-expr-common/src/lib.rs | 1 + .../physical-expr/src/expressions/binary.rs | 14 +- .../physical-expr/src/expressions/datum.rs | 58 -------- .../physical-expr/src/expressions/like.rs | 2 +- .../physical-expr/src/expressions/mod.rs | 1 - .../sqllogictest/test_files/array_query.slt | 9 +- 7 files changed, 148 insertions(+), 64 deletions(-) create mode 100644 datafusion/physical-expr-common/src/datum.rs delete mode 100644 datafusion/physical-expr/src/expressions/datum.rs diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs new file mode 100644 index 0000000000000..f3c6f02506a1b --- /dev/null +++ b/datafusion/physical-expr-common/src/datum.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::BooleanArray; +use arrow::array::{make_comparator, ArrayRef, Datum}; +use arrow::buffer::NullBuffer; +use arrow::compute::SortOptions; +use arrow::error::ArrowError; +use datafusion_common::{internal_err, not_impl_err}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Operator}; +use std::sync::Arc; + +/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` +/// +/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction +pub fn apply( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + match (&lhs, &rhs) { + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { + Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) + } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), + (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; + let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` +pub fn apply_cmp( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like +/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type +pub fn apply_cmp_for_nested( + op: Operator, + lhs: &ColumnarValue, + rhs: &ColumnarValue, +) -> Result { + apply(lhs, rhs, |l, r| { + Ok(Arc::new(compare_op_for_nested(op, l, r)?)) + }) +} + +/// Compare on nested type List, Struct, and so on +fn 
compare_op_for_nested( + op: Operator, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + let (l, is_l_scalar) = lhs.get(); + let (r, is_r_scalar) = rhs.get(); + let l_len = l.len(); + let r_len = r.len(); + if l_len != r_len && !is_l_scalar && !is_r_scalar { + return internal_err!("len mismatch"); + } + + let len = match is_l_scalar { + true => r_len, + false => l_len, + }; + + let cmp = make_comparator(l, r, SortOptions::default())?; + + if !matches!( + op, + Operator::Eq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + | Operator::NotEq + ) { + return not_impl_err!("other operation are not implemented"); + } + + let cmp_with_op = |i, j| match op { + Operator::Eq => cmp(i, j).is_eq(), + Operator::Lt => cmp(i, j).is_lt(), + Operator::Gt => cmp(i, j).is_gt(), + Operator::LtEq => !cmp(i, j).is_gt(), + Operator::GtEq => !cmp(i, j).is_lt(), + Operator::NotEq => !cmp(i, j).is_eq(), + _ => unreachable!("other operatations should be be handled above"), + }; + + let values = match (is_l_scalar, is_r_scalar) { + (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(), + (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(), + (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(), + (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), + }; + + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + Ok(BooleanArray::new(values, nulls)) +} diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 0ddb84141a073..8d50e0b964e5b 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -17,6 +17,7 @@ pub mod aggregate; pub mod binary_map; +pub mod datum; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 98df0cba9f3ec..33eb53a833ad0 100644 --- 
a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,7 +20,6 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; @@ -40,6 +39,7 @@ use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; +use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, @@ -271,7 +271,17 @@ impl PhysicalExpr for BinaryExpr { Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping), Operator::Divide => return apply(&lhs, &rhs, div), Operator::Modulo => return apply(&lhs, &rhs, rem), - Operator::Eq => return apply_cmp(&lhs, &rhs, eq), + Operator::Eq => { + if left_data_type.is_nested() { + if right_data_type != left_data_type { + return internal_err!("type mismatch"); + } + // apply cmp for nested + return apply_cmp_for_nested(self.op, &lhs, &rhs); + } + + return apply_cmp(&lhs, &rhs, eq); + } Operator::NotEq => return apply_cmp(&lhs, &rhs, neq), Operator::Lt => return apply_cmp(&lhs, &rhs, lt), Operator::Gt => return apply_cmp(&lhs, &rhs, gt), diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs deleted file mode 100644 index 2bb79922cfecc..0000000000000 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{ArrayRef, Datum}; -use arrow::error::ArrowError; -use arrow_array::BooleanArray; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; - -/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` -/// -/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction -pub(crate) fn apply( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - match (&lhs, &rhs) { - (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) - } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( - ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), - ), - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( - ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), - ), - (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar()?, &right.to_scalar()?)?; - let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; - Ok(ColumnarValue::Scalar(scalar)) - } - } -} - -/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` -pub(crate) fn apply_cmp( - lhs: &ColumnarValue, - 
rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) -} diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index d18651c641fd3..e0c02b0a90e9c 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc}; use crate::{physical_expr::down_cast_any_ref, PhysicalExpr}; -use crate::expressions::datum::apply_cmp; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::datum::apply_cmp; // Like expression #[derive(Debug, Hash)] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c98bcc56ad97a..608609b81d823 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -21,7 +21,6 @@ mod binary; mod case; mod column; -mod datum; mod in_list; mod is_not_null; mod is_null; diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 24c99fc849b6b..9d1748a36aeb4 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -41,11 +41,16 @@ SELECT * FROM data; # Filtering ########### -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 = [1,2,3]; +---- +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 -query error DataFusion error: Arrow 
error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 = column2 +---- +[2, 3] [2, 3] 1 query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) SELECT * FROM data WHERE column1 != [1,2,3]; From a7c7808ce422e46b26c93d2a726cb95967f1c6bd Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 24 Jun 2024 09:21:03 +0800 Subject: [PATCH 2/5] add more ops Signed-off-by: jayzhan211 --- .../physical-expr/src/expressions/binary.rs | 29 +++++++++----- .../sqllogictest/test_files/array_query.slt | 40 +++++++++++++++---- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 33eb53a833ad0..b9646503bd5b5 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -265,23 +265,30 @@ impl PhysicalExpr for BinaryExpr { let schema = batch.schema(); let input_schema = schema.as_ref(); + if left_data_type.is_nested() { + if right_data_type != left_data_type { + return internal_err!("type mismatch"); + } + if matches!( + self.op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + ) { + return apply_cmp_for_nested(self.op, &lhs, &rhs); + } + } + match self.op { Operator::Plus => return apply(&lhs, &rhs, add_wrapping), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), 
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping), Operator::Divide => return apply(&lhs, &rhs, div), Operator::Modulo => return apply(&lhs, &rhs, rem), - Operator::Eq => { - if left_data_type.is_nested() { - if right_data_type != left_data_type { - return internal_err!("type mismatch"); - } - // apply cmp for nested - return apply_cmp_for_nested(self.op, &lhs, &rhs); - } - - return apply_cmp(&lhs, &rhs, eq); - } + Operator::Eq => return apply_cmp(&lhs, &rhs, eq), Operator::NotEq => return apply_cmp(&lhs, &rhs, neq), Operator::Lt => return apply_cmp(&lhs, &rhs, lt), Operator::Gt => return apply_cmp(&lhs, &rhs, gt), diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 9d1748a36aeb4..138d4e4fb00c1 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -41,22 +41,48 @@ SELECT * FROM data; # Filtering ########### -query ??I +query ??I rowsort SELECT * FROM data WHERE column1 = [1,2,3]; ---- -[1, 2, 3] [4, 5] 1 [1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 query ??I -SELECT * FROM data WHERE column1 = column2 +SELECT * FROM data WHERE column1 != [1,2,3]; ---- [2, 3] [2, 3] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -SELECT * FROM data WHERE column1 != [1,2,3]; - -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 
!= column2 +---- +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 < [1,2,3,4]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 <= [2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 > [1,2]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 >= [1, 2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 ########### # Aggregates From 8ec73a256669ee186fc69f84d09bee3841b925c4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 25 Jun 2024 09:15:33 +0800 Subject: [PATCH 3/5] add distinct Signed-off-by: jayzhan211 --- datafusion/physical-expr-common/src/datum.rs | 89 ++++++++++++++----- .../physical-expr/src/expressions/binary.rs | 12 +-- .../sqllogictest/test_files/array_query.slt | 39 ++++++++ 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index f3c6f02506a1b..6d4ec4050a2ca 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -20,7 +20,7 @@ use arrow::array::{make_comparator, ArrayRef, Datum}; use arrow::buffer::NullBuffer; use arrow::compute::SortOptions; use arrow::error::ArrowError; -use datafusion_common::{internal_err, not_impl_err}; +use datafusion_common::internal_err; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Operator}; use std::sync::Arc; @@ -67,9 +67,23 @@ pub fn apply_cmp_for_nested( lhs: &ColumnarValue, rhs: &ColumnarValue, ) -> Result { - apply(lhs, rhs, |l, r| { - Ok(Arc::new(compare_op_for_nested(op, l, r)?)) - }) + if matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + ) { + apply(lhs, rhs, 
|l, r| { + Ok(Arc::new(compare_op_for_nested(op, l, r)?)) + }) + } else { + internal_err!("invalid operator for nested") + } } /// Compare on nested type List, Struct, and so on @@ -82,6 +96,7 @@ fn compare_op_for_nested( let (r, is_r_scalar) = rhs.get(); let l_len = l.len(); let r_len = r.len(); + if l_len != r_len && !is_l_scalar && !is_r_scalar { return internal_err!("len mismatch"); } @@ -91,28 +106,23 @@ fn compare_op_for_nested( false => l_len, }; - let cmp = make_comparator(l, r, SortOptions::default())?; - - if !matches!( - op, - Operator::Eq - | Operator::Lt - | Operator::Gt - | Operator::LtEq - | Operator::GtEq - | Operator::NotEq - ) { - return not_impl_err!("other operation are not implemented"); + // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly + if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) + && (l.null_count() == 1 || r.null_count() == 1) + { + return Ok(BooleanArray::new_null(len)); } + let cmp = make_comparator(l, r, SortOptions::default())?; + let cmp_with_op = |i, j| match op { - Operator::Eq => cmp(i, j).is_eq(), + Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(), Operator::Lt => cmp(i, j).is_lt(), Operator::Gt => cmp(i, j).is_gt(), Operator::LtEq => !cmp(i, j).is_gt(), Operator::GtEq => !cmp(i, j).is_lt(), - Operator::NotEq => !cmp(i, j).is_eq(), - _ => unreachable!("other operatations should be be handled above"), + Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(), + _ => unreachable!("unexpected operator found"), }; let values = match (is_l_scalar, is_r_scalar) { @@ -122,6 +132,43 @@ fn compare_op_for_nested( (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), }; - let nulls = NullBuffer::union(l.nulls(), r.nulls()); - Ok(BooleanArray::new(values, nulls)) + if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { + Ok(BooleanArray::new(values, None)) + } else { + let nulls = 
NullBuffer::union(l.nulls(), r.nulls()); + Ok(BooleanArray::new(values, nulls)) + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{make_comparator, Array, BooleanArray, ListArray}, + buffer::NullBuffer, + compute::SortOptions, + datatypes::Int32Type, + }; + + #[test] + fn test123() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let a = ListArray::from_iter_primitive::(data); + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let b = ListArray::from_iter_primitive::(data); + let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap(); + let len = a.len().min(b.len()); + let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); + let nulls = NullBuffer::union(a.nulls(), b.nulls()); + println!("res: {:?}", BooleanArray::new(values, nulls)); + } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index b9646503bd5b5..3a8f7ee56ace3 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -269,17 +269,7 @@ impl PhysicalExpr for BinaryExpr { if right_data_type != left_data_type { return internal_err!("type mismatch"); } - if matches!( - self.op, - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::Gt - | Operator::LtEq - | Operator::GtEq - ) { - return apply_cmp_for_nested(self.op, &lhs, &rhs); - } + return apply_cmp_for_nested(self.op, &lhs, &rhs); } match self.op { diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 138d4e4fb00c1..bab4132a87310 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -84,6 +84,26 @@ SELECT * FROM data WHERE column1 >= [1, 2, 3]; [1, 2, 3] [4, 5] 1 
[2, 3] [2, 3] 1 +# test with scalar null +query ??I +SELECT * FROM data WHERE column2 = null; +---- + +query ??I +SELECT * FROM data WHERE null = column2; +---- + +query ??I +SELECT * FROM data WHERE column2 is distinct from null; +---- +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data WHERE column2 is not distinct from null; +---- +[1, 2, 3] NULL 1 + ########### # Aggregates ########### @@ -189,3 +209,22 @@ SELECT * FROM data ORDER BY column1, column3, column2; statement ok drop table data + + +# test filter column with all nulls +statement ok +create table data (a int) as values (null), (null), (null); + +query I +select * from data where a = null; +---- + +query I +select * from data where a is not distinct from null; +---- +NULL +NULL +NULL + +statement ok +drop table data; From 2e9371a0a4147cf31b5f683b61d33c8c58fce324 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 25 Jun 2024 09:28:10 +0800 Subject: [PATCH 4/5] nested Signed-off-by: jayzhan211 --- datafusion/physical-expr-common/src/datum.rs | 2 +- .../sqllogictest/test_files/array_query.slt | 48 ++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 6d4ec4050a2ca..99adb34f9340a 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -108,7 +108,7 @@ fn compare_op_for_nested( // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) - && (l.null_count() == 1 || r.null_count() == 1) + && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1) { return Ok(BooleanArray::new_null(len)); } diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index bab4132a87310..b29b5f5efd986 100644 --- 
a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -96,8 +96,8 @@ SELECT * FROM data WHERE null = column2; query ??I SELECT * FROM data WHERE column2 is distinct from null; ---- -[1, 2, 3] [4, 5] 1 [2, 3] [2, 3] 1 +[1, 2, 3] [4, 5] 1 query ??I SELECT * FROM data WHERE column2 is not distinct from null; @@ -228,3 +228,49 @@ NULL statement ok drop table data; + +statement ok +create table data (a int[][], b int) as values ([[1,2,3]], 1), ([[2,3], [4,5]], 2), (null, 3); + +query ?I +select * from data; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 +NULL 3 + +query ?I +select * from data where a = [[1,2,3]]; +---- +[[1, 2, 3]] 1 + +query ?I +select * from data where a > [[1,2,3]]; +---- +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a > [[1,2]]; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a < [[2, 3]]; +---- +[[1, 2, 3]] 1 + +# compare with null with eq results in null +query ?I +select * from data where a = null; +---- + +query ?I +select * from data where a != null; +---- + +# compare with null with distinct results in true/false +query ?I +select * from data where a is not distinct from null; +---- +NULL 3 From 0628f3b878ed6541209636c09a3203a106bc7de9 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Thu, 27 Jun 2024 08:48:57 +0800 Subject: [PATCH 5/5] add comment Signed-off-by: jayzhan211 --- datafusion/physical-expr-common/src/datum.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 99adb34f9340a..f4ce0eebc0813 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -113,6 +113,8 @@ fn compare_op_for_nested( return Ok(BooleanArray::new_null(len)); } + // TODO: make SortOptions configurable + // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour let cmp = 
make_comparator(l, r, SortOptions::default())?; let cmp_with_op = |i, j| match op { @@ -132,9 +134,13 @@ fn compare_op_for_nested( (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), }; + // The distinct operators understand how to compare with NULL, + // i.e. NULL is distinct from NULL -> false if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { Ok(BooleanArray::new(values, None)) } else { + // If either side is NULL, we return NULL, + // i.e. NULL eq NULL -> NULL let nulls = NullBuffer::union(l.nulls(), r.nulls()); Ok(BooleanArray::new(values, nulls)) }