Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 184 additions & 75 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ use arrow::{

use crate::PhysicalExpr;
use arrow::array::*;
use arrow::buffer::{Buffer, MutableBuffer};
use datafusion_common::ScalarValue;
use datafusion_common::ScalarValue::{
Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeUtf8, UInt16, UInt32, UInt64,
UInt8, Utf8,
Binary, Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeBinary, LargeUtf8,
UInt16, UInt32, UInt64, UInt8, Utf8,
};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
Expand All @@ -49,30 +48,6 @@ use datafusion_expr::ColumnarValue;
/// TODO: add switch codeGen in In_List
static OPTIMIZER_INSET_THRESHOLD: usize = 30;

macro_rules! compare_op_scalar {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

($left: expr, $right:expr, $op:expr) => {{
let null_bit_buffer = $left.data().null_buffer().cloned();

let comparison =
(0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) });
// same as $left.len()
let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) };

let data = unsafe {
ArrayData::new_unchecked(
DataType::Boolean,
$left.len(),
None,
null_bit_buffer,
0,
vec![Buffer::from(buffer)],
vec![],
)
};
Ok(BooleanArray::from(data))
}};
}

/// InList
#[derive(Debug)]
pub struct InListExpr {
Expand Down Expand Up @@ -293,21 +268,6 @@ macro_rules! collection_contains_check_decimal {
}};
}

// whether each value on the left (can be null) is contained in the non-null list
fn in_list_utf8<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
values: &[&str],
) -> Result<BooleanArray> {
compare_op_scalar!(array, values, |x, v: &[&str]| v.contains(&x))
}

fn not_in_list_utf8<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
values: &[&str],
) -> Result<BooleanArray> {
compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
}

// try evaluate all list exprs and check if the exprs are constants or not
fn try_cast_static_filter_to_set(
list: &[Arc<dyn PhysicalExpr>],
Expand Down Expand Up @@ -386,8 +346,7 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
let native_array = set
.iter()
.flat_map(|v| match v {
Utf8(v) => v.as_deref(),
LargeUtf8(v) => v.as_deref(),
Utf8(v) | LargeUtf8(v) => v.as_deref(),
datatype => {
unreachable!("InList can't reach other data type {} for {}.", datatype, v)
}
Expand All @@ -398,6 +357,26 @@ fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
collection_contains_check!(array, native_set, negated, contains_null)
}

fn set_contains_binary<OffsetSize: OffsetSizeTrait>(
array: &GenericBinaryArray<OffsetSize>,
set: &HashSet<ScalarValue>,
negated: bool,
) -> ColumnarValue {
let contains_null = set.iter().any(|v| v.is_null());
let native_array = set
.iter()
.flat_map(|v| match v {
Binary(v) | LargeBinary(v) => v.as_deref(),
datatype => {
unreachable!("InList can't reach other data type {} for {}.", datatype, v)
}
})
.collect::<Vec<_>>();
let native_set: HashSet<&[u8]> = HashSet::from_iter(native_array);
Comment on lines +374 to +375
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you skip this intermediate Vec and instead go directly to HashSet?

Like .collect::<HashSet<_>>() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do this in a follow-up patch as we have this problem in both binary and string


collection_contains_check!(array, native_set, negated, contains_null)
}

impl InListExpr {
/// Create a new InList expression
pub fn new(
Expand Down Expand Up @@ -471,37 +450,50 @@ impl InListExpr {
})
.collect::<Vec<&str>>();

if negated {
if contains_null {
Ok(ColumnarValue::Array(Arc::new(
array
.iter()
.map(|x| match x.map(|v| !values.contains(&v)) {
Some(true) => None,
x => x,
})
.collect::<BooleanArray>(),
)))
} else {
Ok(ColumnarValue::Array(Arc::new(not_in_list_utf8(
array, &values,
)?)))
}
} else if contains_null {
Ok(ColumnarValue::Array(Arc::new(
array
.iter()
.map(|x| match x.map(|v| values.contains(&v)) {
Some(false) => None,
x => x,
})
.collect::<BooleanArray>(),
)))
} else {
Ok(ColumnarValue::Array(Arc::new(in_list_utf8(
array, &values,
)?)))
}
Ok(collection_contains_check!(
array,
values,
negated,
contains_null
))
}

fn compare_binary<T: OffsetSizeTrait>(
&self,
array: ArrayRef,
list_values: Vec<ColumnarValue>,
negated: bool,
) -> Result<ColumnarValue> {
let array = array
.as_any()
.downcast_ref::<GenericBinaryArray<T>>()
.unwrap();

let contains_null = list_values
.iter()
.any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
let values = list_values
.iter()
.flat_map(|expr| match expr {
ColumnarValue::Scalar(s) => match s {
ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
Some(v.as_slice())
}
ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) => None,
datatype => unimplemented!("Unexpected type {} for InList", datatype),
},
ColumnarValue::Array(_) => {
unimplemented!("InList does not yet support nested columns.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this case is when one of the arguments to InList is a column (rather than a constant)

So like col1 IN ('1', '2', col2) type thing

(no change needed I was just trying to clarify)

}
})
.collect::<Vec<&[u8]>>();

Ok(collection_contains_check!(
array,
values,
negated,
contains_null
))
}
}

Expand Down Expand Up @@ -670,6 +662,20 @@ impl PhysicalExpr for InListExpr {
.unwrap();
Ok(set_contains_utf8(array, set, self.negated))
}
DataType::Binary => {
let array = array
.as_any()
.downcast_ref::<GenericBinaryArray<i32>>()
.unwrap();
Ok(set_contains_binary(array, set, self.negated))
}
DataType::LargeBinary => {
let array = array
.as_any()
.downcast_ref::<GenericBinaryArray<i64>>()
.unwrap();
Ok(set_contains_binary(array, set, self.negated))
}
DataType::Decimal128(_, _) => {
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
Ok(make_set_contains_decimal(array, set, self.negated))
Expand Down Expand Up @@ -795,6 +801,12 @@ impl PhysicalExpr for InListExpr {
DataType::LargeUtf8 => {
self.compare_utf8::<i64>(array, list_values, self.negated)
}
DataType::Binary => {
self.compare_binary::<i32>(array, list_values, self.negated)
}
DataType::LargeBinary => {
self.compare_binary::<i64>(array, list_values, self.negated)
}
DataType::Null => {
let null_array = new_null_array(&DataType::Boolean, array.len());
Ok(ColumnarValue::Array(Arc::new(null_array)))
Expand Down Expand Up @@ -906,6 +918,66 @@ mod tests {
Ok(())
}

#[test]
fn in_list_binary() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
let a = BinaryArray::from(vec![
Some([1, 2, 3].as_slice()),
Some([1, 2, 2].as_slice()),
None,
]);
let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

// expression: "a in ([1, 2, 3], [4, 5, 6])"
let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())];
in_list!(
batch,
list.clone(),
&false,
vec![Some(true), Some(false), None],
col_a.clone(),
&schema
);

// expression: "a not in ([1, 2, 3], [4, 5, 6])"
in_list!(
batch,
list,
&true,
vec![Some(false), Some(true), None],
col_a.clone(),
&schema
);

// expression: "a in ([1, 2, 3], [4, 5, 6], null)"
let list = vec![
lit([1, 2, 3].as_slice()),
lit([4, 5, 6].as_slice()),
lit(ScalarValue::Binary(None)),
];
in_list!(
batch,
list.clone(),
&false,
vec![Some(true), None, None],
col_a.clone(),
&schema
);

// expression: "a in ([1, 2, 3], [4, 5, 6], null)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the comment should be

// expression: "a not in ([1, 2, 3], [4, 5, 6], null)"

cc @HaoYang670

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Op, my mistake. Sorry.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind.
You can fix it in the next pr

in_list!(
batch,
list,
&true,
vec![Some(false), None, None],
col_a.clone(),
&schema
);

Ok(())
}

#[test]
fn in_list_int64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
Expand Down Expand Up @@ -1316,6 +1388,43 @@ mod tests {
Ok(())
}

#[test]
fn in_list_set_binary() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
let a = BinaryArray::from(vec![
Some([1, 2, 3].as_slice()),
Some([3, 2, 1].as_slice()),
None,
]);
let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

let mut list = vec![lit([1, 2, 3].as_slice()), lit(ScalarValue::Binary(None))];
for v in 0..OPTIMIZER_INSET_THRESHOLD {
list.push(lit([v as u8].as_slice()));
}

in_list!(
batch,
list.clone(),
&false,
vec![Some(true), None, None],
col_a.clone(),
&schema
);

in_list!(
batch,
list.clone(),
&true,
vec![Some(false), None, None],
col_a.clone(),
&schema
);

Ok(())
}

#[test]
fn in_list_set_decimal() -> Result<()> {
let schema =
Expand Down