diff --git a/datafusion/benches/filter_query_sql.rs b/datafusion/benches/filter_query_sql.rs index 253ef455f5af2..aac7f96248725 100644 --- a/datafusion/benches/filter_query_sql.rs +++ b/datafusion/benches/filter_query_sql.rs @@ -25,16 +25,14 @@ use datafusion::prelude::ExecutionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; use std::sync::Arc; +use tokio::runtime::Runtime; async fn query(ctx: &mut ExecutionContext, sql: &str) { + let rt = Runtime::new().unwrap(); + // execute the query let df = ctx.sql(sql).unwrap(); - let results = df.collect().await.unwrap(); - - // display the relation - for _batch in results { - // println!("num_rows: {}", _batch.num_rows()); - } + criterion::black_box(rt.block_on(df.collect()).unwrap()); } fn create_context(array_len: usize, batch_size: usize) -> Result { @@ -85,6 +83,16 @@ fn criterion_benchmark(c: &mut Criterion) { )) }) }); + + c.bench_function("filter_scalar in list", |b| { + let mut ctx = create_context(array_len, batch_size).unwrap(); + b.iter(|| { + block_on(query( + &mut ctx, + "select f32, f64 from t where f32 in (10, 20, 30, 40)", + )) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/src/physical_plan/expressions/in_list.rs b/datafusion/src/physical_plan/expressions/in_list.rs index 38b2b9d45b9bb..00767c7a67079 100644 --- a/datafusion/src/physical_plan/expressions/in_list.rs +++ b/datafusion/src/physical_plan/expressions/in_list.rs @@ -26,14 +26,39 @@ use arrow::array::{ Int64Array, Int8Array, StringOffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; +use arrow::datatypes::ArrowPrimitiveType; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use crate::error::Result; +use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use crate::scalar::ScalarValue; +use arrow::array::*; +use arrow::buffer::{Buffer, MutableBuffer}; + +macro_rules! compare_op_scalar { + ($left: expr, $right:expr, $op:expr) => {{ + let null_bit_buffer = $left.data().null_buffer().cloned(); + + let comparison = + (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i), $right) }); + // same as $left.len() + let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; + + let data = ArrayData::new( + DataType::Boolean, + $left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ); + Ok(BooleanArray::from(data)) + }}; +} /// InList #[derive(Debug)] @@ -47,20 +72,16 @@ macro_rules! make_contains { ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{ let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - let mut contains_null = false; + let contains_null = $LIST_VALUES + .iter() + .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null())); let values = $LIST_VALUES .iter() .flat_map(|expr| match expr { ColumnarValue::Scalar(s) => match s { ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v), - ScalarValue::$SCALAR_VALUE(None) => { - contains_null = true; - None - } - ScalarValue::Utf8(None) => { - contains_null = true; - None - } + ScalarValue::$SCALAR_VALUE(None) => None, + ScalarValue::Utf8(None) => None, datatype => unimplemented!("Unexpected type {} for InList", datatype), }, ColumnarValue::Array(_) => { @@ -99,6 +120,103 @@ macro_rules! make_contains { }}; } +macro_rules! make_contains_primitive { + ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{ + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + + let contains_null = $LIST_VALUES + .iter() + .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null())); + let values = $LIST_VALUES + .iter() + .flat_map(|expr| match expr { + ColumnarValue::Scalar(s) => match s { + ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v), + ScalarValue::$SCALAR_VALUE(None) => None, + ScalarValue::Utf8(None) => None, + datatype => unimplemented!("Unexpected type {} for InList", datatype), + }, + ColumnarValue::Array(_) => { + unimplemented!("InList does not yet support nested columns.") + } + }) + .collect::>(); + + if $NEGATED { + if contains_null { + Ok(ColumnarValue::Array(Arc::new( + array + .iter() + .map(|x| match x.map(|v| !values.contains(&v)) { + Some(true) => None, + x => x, + }) + .collect::(), + ))) + } else { + Ok(ColumnarValue::Array(Arc::new( + not_in_list_primitive(array, &values)?, + ))) + } + } else { + if contains_null { + Ok(ColumnarValue::Array(Arc::new( + array + .iter() + .map(|x| match x.map(|v| values.contains(&v)) { + Some(false) => None, + x => x, + }) + .collect::(), + ))) + } else { + Ok(ColumnarValue::Array(Arc::new(in_list_primitive( + array, &values, + )?))) + } + } + }}; +} + +// whether each value on the left (can be null) is contained in the non-null list +fn in_list_primitive( + array: &PrimitiveArray, + values: &[::Native], +) -> Result { + compare_op_scalar!( + array, + values, + |x, v: &[::Native]| v.contains(&x) + ) +} + +// whether each value on the left (can be null) is contained in the non-null list +fn not_in_list_primitive( + array: &PrimitiveArray, + values: &[::Native], +) -> Result { + compare_op_scalar!( + array, + values, + |x, v: &[::Native]| !v.contains(&x) + ) +} + +// whether each value on the left (can be null) is contained in the non-null list +fn in_list_utf8( + array: &GenericStringArray, + values: &[&str], +) -> Result { + compare_op_scalar!(array, values, |x, v: &[&str]| v.contains(&x)) +} + +fn not_in_list_utf8( + array: &GenericStringArray, + values: &[&str], +) -> Result { + compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x)) +} + impl InListExpr { /// Create a new InList expression pub fn new( @@ -141,21 +259,17 @@ impl InListExpr { .downcast_ref::>() .unwrap(); - let mut contains_null = false; + let contains_null = list_values + .iter() + .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null())); let values = list_values .iter() .flat_map(|expr| match expr { ColumnarValue::Scalar(s) => match s { ScalarValue::Utf8(Some(v)) => Some(v.as_str()), - ScalarValue::Utf8(None) => { - contains_null = true; - None - } + ScalarValue::Utf8(None) => None, ScalarValue::LargeUtf8(Some(v)) => Some(v.as_str()), - ScalarValue::LargeUtf8(None) => { - contains_null = true; - None - } + ScalarValue::LargeUtf8(None) => None, datatype => unimplemented!("Unexpected type {} for InList", datatype), }, ColumnarValue::Array(_) => { @@ -164,33 +278,37 @@ impl InListExpr { }) .collect::>(); - Ok(ColumnarValue::Array(Arc::new( - array - .iter() - .map(|x| { - let contains = x.map(|x| values.contains(&x)); - match contains { - Some(true) => { - if negated { - Some(false) - } else { - Some(true) - } - } - Some(false) => { - if contains_null { - None - } else if negated { - Some(true) - } else { - Some(false) - } - } - None => None, - } - }) - .collect::(), - ))) + if negated { + if contains_null { + Ok(ColumnarValue::Array(Arc::new( + array + .iter() + .map(|x| match x.map(|v| !values.contains(&v)) { + Some(true) => None, + x => x, + }) + .collect::(), + ))) + } else { + Ok(ColumnarValue::Array(Arc::new(not_in_list_utf8( + array, &values, + )?))) + } + } else if contains_null { + Ok(ColumnarValue::Array(Arc::new( + array + .iter() + .map(|x| match x.map(|v| values.contains(&v)) { + Some(false) => None, + x => x, + }) + .collect::(), + ))) + } else { + Ok(ColumnarValue::Array(Arc::new(in_list_utf8( + array, &values, + )?))) + } } } @@ -234,34 +352,94 @@ impl PhysicalExpr for InListExpr { match value_data_type { DataType::Float32 => { - make_contains!(array, list_values, self.negated, Float32, Float32Array) + make_contains_primitive!( + array, + list_values, + self.negated, + Float32, + Float32Array + ) } DataType::Float64 => { - make_contains!(array, list_values, self.negated, Float64, Float64Array) + make_contains_primitive!( + array, + list_values, + self.negated, + Float64, + Float64Array + ) } DataType::Int16 => { - make_contains!(array, list_values, self.negated, Int16, Int16Array) + make_contains_primitive!( + array, + list_values, + self.negated, + Int16, + Int16Array + ) } DataType::Int32 => { - make_contains!(array, list_values, self.negated, Int32, Int32Array) + make_contains_primitive!( + array, + list_values, + self.negated, + Int32, + Int32Array + ) } DataType::Int64 => { - make_contains!(array, list_values, self.negated, Int64, Int64Array) + make_contains_primitive!( + array, + list_values, + self.negated, + Int64, + Int64Array + ) } DataType::Int8 => { - make_contains!(array, list_values, self.negated, Int8, Int8Array) + make_contains_primitive!( + array, + list_values, + self.negated, + Int8, + Int8Array + ) } DataType::UInt16 => { - make_contains!(array, list_values, self.negated, UInt16, UInt16Array) + make_contains_primitive!( + array, + list_values, + self.negated, + UInt16, + UInt16Array + ) } DataType::UInt32 => { - make_contains!(array, list_values, self.negated, UInt32, UInt32Array) + make_contains_primitive!( + array, + list_values, + self.negated, + UInt32, + UInt32Array + ) } DataType::UInt64 => { - make_contains!(array, list_values, self.negated, UInt64, UInt64Array) + make_contains_primitive!( + array, + list_values, + self.negated, + UInt64, + UInt64Array + ) } DataType::UInt8 => { - make_contains!(array, list_values, self.negated, UInt8, UInt8Array) + make_contains_primitive!( + array, + list_values, + self.negated, + UInt8, + UInt8Array + ) } DataType::Boolean => { make_contains!(array, list_values, self.negated, Boolean, BooleanArray) @@ -270,9 +448,10 @@ impl PhysicalExpr for InListExpr { DataType::LargeUtf8 => { self.compare_utf8::(array, list_values, self.negated) } - datatype => { - unimplemented!("InList does not support datatype {:?}.", datatype) - } + datatype => Result::Err(DataFusionError::NotImplemented(format!( + "InList does not support datatype {:?}.", + datatype + ))), } } }