-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Add InList support for binary type.
#3324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,11 +34,10 @@ use arrow::{ | |
|
|
||
| use crate::PhysicalExpr; | ||
| use arrow::array::*; | ||
| use arrow::buffer::{Buffer, MutableBuffer}; | ||
| use datafusion_common::ScalarValue; | ||
| use datafusion_common::ScalarValue::{ | ||
| Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeUtf8, UInt16, UInt32, UInt64, | ||
| UInt8, Utf8, | ||
| Binary, Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeBinary, LargeUtf8, | ||
| UInt16, UInt32, UInt64, UInt8, Utf8, | ||
| }; | ||
| use datafusion_common::{DataFusionError, Result}; | ||
| use datafusion_expr::ColumnarValue; | ||
|
|
@@ -49,30 +48,6 @@ use datafusion_expr::ColumnarValue; | |
| /// TODO: add switch codeGen in In_List | ||
| static OPTIMIZER_INSET_THRESHOLD: usize = 30; | ||
|
|
||
| macro_rules! compare_op_scalar { | ||
| ($left: expr, $right:expr, $op:expr) => {{ | ||
| let null_bit_buffer = $left.data().null_buffer().cloned(); | ||
|
|
||
| let comparison = | ||
| (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) }); | ||
| // same as $left.len() | ||
| let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; | ||
|
|
||
| let data = unsafe { | ||
| ArrayData::new_unchecked( | ||
| DataType::Boolean, | ||
| $left.len(), | ||
| None, | ||
| null_bit_buffer, | ||
| 0, | ||
| vec![Buffer::from(buffer)], | ||
| vec![], | ||
| ) | ||
| }; | ||
| Ok(BooleanArray::from(data)) | ||
| }}; | ||
| } | ||
|
|
||
| /// InList | ||
| #[derive(Debug)] | ||
| pub struct InListExpr { | ||
|
|
@@ -293,21 +268,6 @@ macro_rules! collection_contains_check_decimal { | |
| }}; | ||
| } | ||
|
|
||
| // whether each value on the left (can be null) is contained in the non-null list | ||
| fn in_list_utf8<OffsetSize: OffsetSizeTrait>( | ||
| array: &GenericStringArray<OffsetSize>, | ||
| values: &[&str], | ||
| ) -> Result<BooleanArray> { | ||
| compare_op_scalar!(array, values, |x, v: &[&str]| v.contains(&x)) | ||
| } | ||
|
|
||
| fn not_in_list_utf8<OffsetSize: OffsetSizeTrait>( | ||
| array: &GenericStringArray<OffsetSize>, | ||
| values: &[&str], | ||
| ) -> Result<BooleanArray> { | ||
| compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x)) | ||
| } | ||
|
|
||
| // try evaluate all list exprs and check if the exprs are constants or not | ||
| fn try_cast_static_filter_to_set( | ||
| list: &[Arc<dyn PhysicalExpr>], | ||
|
|
@@ -386,8 +346,7 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>( | |
| let native_array = set | ||
| .iter() | ||
| .flat_map(|v| match v { | ||
| Utf8(v) => v.as_deref(), | ||
| LargeUtf8(v) => v.as_deref(), | ||
| Utf8(v) | LargeUtf8(v) => v.as_deref(), | ||
| datatype => { | ||
| unreachable!("InList can't reach other data type {} for {}.", datatype, v) | ||
| } | ||
|
|
@@ -398,6 +357,26 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>( | |
| collection_contains_check!(array, native_set, negated, contains_null) | ||
| } | ||
|
|
||
| fn set_contains_binary<OffsetSize: OffsetSizeTrait>( | ||
| array: &GenericBinaryArray<OffsetSize>, | ||
| set: &HashSet<ScalarValue>, | ||
| negated: bool, | ||
| ) -> ColumnarValue { | ||
| let contains_null = set.iter().any(|v| v.is_null()); | ||
| let native_array = set | ||
| .iter() | ||
| .flat_map(|v| match v { | ||
| Binary(v) | LargeBinary(v) => v.as_deref(), | ||
| datatype => { | ||
| unreachable!("InList can't reach other data type {} for {}.", datatype, v) | ||
| } | ||
| }) | ||
| .collect::<Vec<_>>(); | ||
| let native_set: HashSet<&[u8]> = HashSet::from_iter(native_array); | ||
|
Comment on lines
+374
to
+375
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you skip this intermediate Like
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will do this in a follow-up patch as we have this problem in both |
||
|
|
||
| collection_contains_check!(array, native_set, negated, contains_null) | ||
| } | ||
|
|
||
| impl InListExpr { | ||
| /// Create a new InList expression | ||
| pub fn new( | ||
|
|
@@ -471,37 +450,50 @@ impl InListExpr { | |
| }) | ||
| .collect::<Vec<&str>>(); | ||
|
|
||
| if negated { | ||
| if contains_null { | ||
| Ok(ColumnarValue::Array(Arc::new( | ||
| array | ||
| .iter() | ||
| .map(|x| match x.map(|v| !values.contains(&v)) { | ||
| Some(true) => None, | ||
| x => x, | ||
| }) | ||
| .collect::<BooleanArray>(), | ||
| ))) | ||
| } else { | ||
| Ok(ColumnarValue::Array(Arc::new(not_in_list_utf8( | ||
| array, &values, | ||
| )?))) | ||
| } | ||
| } else if contains_null { | ||
| Ok(ColumnarValue::Array(Arc::new( | ||
| array | ||
| .iter() | ||
| .map(|x| match x.map(|v| values.contains(&v)) { | ||
| Some(false) => None, | ||
| x => x, | ||
| }) | ||
| .collect::<BooleanArray>(), | ||
| ))) | ||
| } else { | ||
| Ok(ColumnarValue::Array(Arc::new(in_list_utf8( | ||
| array, &values, | ||
| )?))) | ||
| } | ||
| Ok(collection_contains_check!( | ||
| array, | ||
| values, | ||
| negated, | ||
| contains_null | ||
| )) | ||
| } | ||
|
|
||
| fn compare_binary<T: OffsetSizeTrait>( | ||
| &self, | ||
| array: ArrayRef, | ||
| list_values: Vec<ColumnarValue>, | ||
| negated: bool, | ||
| ) -> Result<ColumnarValue> { | ||
| let array = array | ||
| .as_any() | ||
| .downcast_ref::<GenericBinaryArray<T>>() | ||
| .unwrap(); | ||
|
|
||
| let contains_null = list_values | ||
| .iter() | ||
| .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null())); | ||
| let values = list_values | ||
| .iter() | ||
| .flat_map(|expr| match expr { | ||
| ColumnarValue::Scalar(s) => match s { | ||
| ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => { | ||
| Some(v.as_slice()) | ||
| } | ||
| ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) => None, | ||
| datatype => unimplemented!("Unexpected type {} for InList", datatype), | ||
| }, | ||
| ColumnarValue::Array(_) => { | ||
| unimplemented!("InList does not yet support nested columns.") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this case is when one of the arguments to So like (no change needed I was just trying to clarify) |
||
| } | ||
| }) | ||
| .collect::<Vec<&[u8]>>(); | ||
|
|
||
| Ok(collection_contains_check!( | ||
| array, | ||
| values, | ||
| negated, | ||
| contains_null | ||
| )) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -670,6 +662,20 @@ impl PhysicalExpr for InListExpr { | |
| .unwrap(); | ||
| Ok(set_contains_utf8(array, set, self.negated)) | ||
| } | ||
| DataType::Binary => { | ||
| let array = array | ||
| .as_any() | ||
| .downcast_ref::<GenericBinaryArray<i32>>() | ||
| .unwrap(); | ||
| Ok(set_contains_binary(array, set, self.negated)) | ||
| } | ||
| DataType::LargeBinary => { | ||
| let array = array | ||
| .as_any() | ||
| .downcast_ref::<GenericBinaryArray<i64>>() | ||
| .unwrap(); | ||
| Ok(set_contains_binary(array, set, self.negated)) | ||
| } | ||
| DataType::Decimal128(_, _) => { | ||
| let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap(); | ||
| Ok(make_set_contains_decimal(array, set, self.negated)) | ||
|
|
@@ -795,6 +801,12 @@ impl PhysicalExpr for InListExpr { | |
| DataType::LargeUtf8 => { | ||
| self.compare_utf8::<i64>(array, list_values, self.negated) | ||
| } | ||
| DataType::Binary => { | ||
| self.compare_binary::<i32>(array, list_values, self.negated) | ||
| } | ||
| DataType::LargeBinary => { | ||
| self.compare_binary::<i64>(array, list_values, self.negated) | ||
| } | ||
| DataType::Null => { | ||
| let null_array = new_null_array(&DataType::Boolean, array.len()); | ||
| Ok(ColumnarValue::Array(Arc::new(null_array))) | ||
|
|
@@ -906,6 +918,66 @@ mod tests { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn in_list_binary() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]); | ||
| let a = BinaryArray::from(vec![ | ||
| Some([1, 2, 3].as_slice()), | ||
| Some([1, 2, 2].as_slice()), | ||
| None, | ||
| ]); | ||
| let col_a = col("a", &schema)?; | ||
| let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; | ||
|
|
||
| // expression: "a in ([1, 2, 3], [4, 5, 6])" | ||
| let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; | ||
| in_list!( | ||
| batch, | ||
| list.clone(), | ||
| &false, | ||
| vec![Some(true), Some(false), None], | ||
| col_a.clone(), | ||
| &schema | ||
| ); | ||
|
|
||
| // expression: "a not in ([1, 2, 3], [4, 5, 6])" | ||
| in_list!( | ||
| batch, | ||
| list, | ||
| &true, | ||
| vec![Some(false), Some(true), None], | ||
| col_a.clone(), | ||
| &schema | ||
| ); | ||
|
|
||
| // expression: "a in ([1, 2, 3], [4, 5, 6], null)" | ||
| let list = vec![ | ||
| lit([1, 2, 3].as_slice()), | ||
| lit([4, 5, 6].as_slice()), | ||
| lit(ScalarValue::Binary(None)), | ||
| ]; | ||
| in_list!( | ||
| batch, | ||
| list.clone(), | ||
| &false, | ||
| vec![Some(true), None, None], | ||
| col_a.clone(), | ||
| &schema | ||
| ); | ||
|
|
||
| // expression: "a in ([1, 2, 3], [4, 5, 6], null)" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Op, my mistake. Sorry.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Never mind. |
||
| in_list!( | ||
| batch, | ||
| list, | ||
| &true, | ||
| vec![Some(false), None, None], | ||
| col_a.clone(), | ||
| &schema | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn in_list_int64() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); | ||
|
|
@@ -1316,6 +1388,43 @@ mod tests { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn in_list_set_binary() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]); | ||
| let a = BinaryArray::from(vec![ | ||
| Some([1, 2, 3].as_slice()), | ||
| Some([3, 2, 1].as_slice()), | ||
| None, | ||
| ]); | ||
| let col_a = col("a", &schema)?; | ||
| let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; | ||
|
|
||
| let mut list = vec![lit([1, 2, 3].as_slice()), lit(ScalarValue::Binary(None))]; | ||
| for v in 0..OPTIMIZER_INSET_THRESHOLD { | ||
| list.push(lit([v as u8].as_slice())); | ||
| } | ||
|
|
||
| in_list!( | ||
| batch, | ||
| list.clone(), | ||
| &false, | ||
| vec![Some(true), None, None], | ||
| col_a.clone(), | ||
| &schema | ||
| ); | ||
|
|
||
| in_list!( | ||
| batch, | ||
| list.clone(), | ||
| &true, | ||
| vec![Some(false), None, None], | ||
| col_a.clone(), | ||
| &schema | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn in_list_set_decimal() -> Result<()> { | ||
| let schema = | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️