-
Notifications
You must be signed in to change notification settings - Fork 2.1k
feat: Improve InListExpr types, flatten dict haystacks and validate in try_new_from_array #21402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,7 +29,7 @@ use arrow::array::*; | |
| use arrow::buffer::{BooleanBuffer, NullBuffer}; | ||
| use arrow::compute::kernels::boolean::{not, or_kleene}; | ||
| use arrow::compute::kernels::cmp::eq as arrow_eq; | ||
| use arrow::compute::{SortOptions, take}; | ||
| use arrow::compute::{SortOptions, cast, take}; | ||
| use arrow::datatypes::*; | ||
| use arrow::util::bit_iterator::BitIndexIterator; | ||
| use datafusion_common::hash_utils::with_hashes; | ||
|
|
@@ -43,11 +43,21 @@ use datafusion_common::HashMap; | |
| use datafusion_common::hash_utils::RandomState; | ||
| use hashbrown::hash_map::RawEntryMut; | ||
|
|
||
| /// Trait for InList static filters | ||
| /// Trait for InList static filters. | ||
| /// | ||
| /// Static filters store a pre-computed set of values (the haystack) and check | ||
| /// whether needle values are contained in that set. The haystack is always | ||
| /// represented in its non-dictionary (value) type. Dictionary haystacks are | ||
| /// flattened via `cast()` before construction. | ||
| /// | ||
| /// Dictionary-encoded needles are unwrapped inside `contains()` and | ||
| /// evaluated against the dictionary's values. | ||
| trait StaticFilter { | ||
| fn null_count(&self) -> usize; | ||
|
|
||
| /// Checks if values in `v` are contained in the filter | ||
| /// Checks if values in `v` (needle) are contained in this filter's | ||
| /// haystack. `v` may be dictionary-encoded, in which case the | ||
| /// implementation unwraps the dictionary and operates on its values. | ||
| fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>; | ||
| } | ||
|
|
||
|
|
@@ -164,6 +174,13 @@ fn supports_arrow_eq(dt: &DataType) -> bool { | |
| fn instantiate_static_filter( | ||
| in_array: ArrayRef, | ||
| ) -> Result<Arc<dyn StaticFilter + Send + Sync>> { | ||
| // Flatten dictionary-encoded haystacks to their value type so that | ||
| // specialized filters (e.g. Int32StaticFilter) are used instead of | ||
| // falling through to the generic ArrayStaticFilter. | ||
| let in_array = match in_array.data_type() { | ||
| DataType::Dictionary(_, value_type) => cast(&in_array, value_type.as_ref())?, | ||
| _ => in_array, | ||
| }; | ||
| match in_array.data_type() { | ||
| // Integer primitive types | ||
| DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), | ||
|
|
@@ -642,20 +659,34 @@ impl InListExpr { | |
|
|
||
| /// Create a new InList expression directly from an array, bypassing expression evaluation. | ||
| /// | ||
| /// This is more efficient than `in_list()` when you already have the list as an array, | ||
| /// as it avoids the conversion: `ArrayRef -> Vec<PhysicalExpr> -> ArrayRef -> StaticFilter`. | ||
| /// Instead it goes directly: `ArrayRef -> StaticFilter`. | ||
| /// This is more efficient than [`InListExpr::try_new`] when you already have the list | ||
| /// as an array, as it builds the static filter directly from the array instead of | ||
| /// reconstructing an intermediate array from literal expressions. | ||
| /// | ||
| /// The `list` field is populated with literal expressions extracted from | ||
| /// the array, and the array is used to build a static filter for | ||
| /// efficient set membership evaluation. | ||
| /// | ||
| /// The `list` field will be empty when using this constructor, as the array is stored | ||
| /// directly in the static filter. | ||
| /// The `array` may be dictionary-encoded — it will be flattened to its | ||
| /// value type such that specialized filters are used. | ||
| /// | ||
| /// This does not make the expression any more performant at runtime, but it does make it slightly | ||
| /// cheaper to build. | ||
| /// Returns an error if the expression's data type and the array's data type | ||
| /// are not logically equal. Null arrays are always accepted. | ||
| pub fn try_new_from_array( | ||
| expr: Arc<dyn PhysicalExpr>, | ||
| array: ArrayRef, | ||
| negated: bool, | ||
| schema: &Schema, | ||
| ) -> Result<Self> { | ||
| let expr_data_type = expr.data_type(schema)?; | ||
| let array_data_type = array.data_type(); | ||
| if *array_data_type != DataType::Null { | ||
| assert_or_internal_err!( | ||
| DFSchema::datatype_is_logically_equal(&expr_data_type, array_data_type), | ||
| "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {array_data_type}" | ||
| ); | ||
| } | ||
|
|
||
| let list = (0..array.len()) | ||
| .map(|i| { | ||
| let scalar = ScalarValue::try_from_array(array.as_ref(), i)?; | ||
|
|
@@ -2318,6 +2349,7 @@ mod tests { | |
| Arc::clone(&col_a), | ||
| array, | ||
| false, | ||
| &schema, | ||
| )?) as Arc<dyn PhysicalExpr>; | ||
|
|
||
| // Create test data: [1, 2, 3, 4, null] | ||
|
|
@@ -2447,6 +2479,7 @@ mod tests { | |
| Arc::clone(&col_a), | ||
| null_array, | ||
| false, | ||
| &schema, | ||
| )?) as Arc<dyn PhysicalExpr>; | ||
| let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
|
|
@@ -2475,6 +2508,7 @@ mod tests { | |
| Arc::clone(&col_a), | ||
| null_array, | ||
| false, | ||
| &schema, | ||
| )?) as Arc<dyn PhysicalExpr>; | ||
|
|
||
| let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; | ||
|
|
@@ -3911,8 +3945,9 @@ mod tests { | |
| let schema = | ||
| Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]); | ||
| let col_a = col("a", &schema)?; | ||
| let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?) | ||
| as Arc<dyn PhysicalExpr>; | ||
| let expr = Arc::new(InListExpr::try_new_from_array( | ||
| col_a, in_array, false, &schema, | ||
| )?) as Arc<dyn PhysicalExpr>; | ||
| let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| Ok(as_boolean_array(&result).clone()) | ||
|
|
@@ -4045,43 +4080,182 @@ mod tests { | |
| Ok(()) | ||
| } | ||
|
|
||
| fn make_int32_dict_array(values: Vec<Option<i32>>) -> ArrayRef { | ||
| let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new(); | ||
| for v in values { | ||
| match v { | ||
| Some(val) => builder.append_value(val), | ||
| None => builder.append_null(), | ||
| } | ||
| } | ||
| Arc::new(builder.finish()) | ||
| } | ||
|
|
||
| fn make_f64_dict_array(values: Vec<Option<f64>>) -> ArrayRef { | ||
| let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Float64Type>::new(); | ||
| for v in values { | ||
| match v { | ||
| Some(val) => builder.append_value(val), | ||
| None => builder.append_null(), | ||
| } | ||
| } | ||
| Arc::new(builder.finish()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_try_new_from_array_dict_haystack_int32() -> Result<()> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice coverage here for primitive, string, and float dictionary haystacks. One gap I still see is the multi-column or |
||
| let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); | ||
| let needle = Int32Array::from(vec![1, 2, 3, 4]); | ||
| let batch = | ||
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; | ||
|
|
||
| let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]); | ||
|
|
||
| let col_a = col("a", &schema)?; | ||
| let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| let result = as_boolean_array(&result); | ||
| assert_eq!( | ||
| result, | ||
| &BooleanArray::from(vec![Some(true), None, Some(true), None]) | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_in_list_from_array_type_mismatch_errors() -> Result<()> { | ||
| // Utf8 needle, Dict(Utf8) in_array | ||
| let err = eval_in_list_from_array( | ||
| Arc::new(StringArray::from(vec!["a", "d", "b"])), | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), | ||
| ) | ||
| .unwrap_err() | ||
| .to_string(); | ||
| assert!( | ||
| err.contains("Can't compare arrays of different types"), | ||
| "{err}" | ||
| // Utf8 needle, Dict(Utf8) in_array: now works with dict haystack support | ||
| assert_eq!( | ||
| BooleanArray::from(vec![Some(true), Some(false), Some(true)]), | ||
| eval_in_list_from_array( | ||
| Arc::new(StringArray::from(vec!["a", "d", "b"])), | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), | ||
| )? | ||
| ); | ||
|
|
||
| // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter | ||
| // rejects the Utf8 dictionary values at construction time | ||
| // Dict(Utf8) needle, Int64 in_array: type validation rejects at construction | ||
| let err = eval_in_list_from_array( | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))), | ||
| Arc::new(Int64Array::from(vec![1, 2, 3])), | ||
| ) | ||
| .unwrap_err() | ||
| .to_string(); | ||
| assert!(err.contains("Failed to downcast"), "{err}"); | ||
| assert!(err.contains("The data type inlist should be same"), "{err}"); | ||
|
|
||
| // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different | ||
| // value types, make_comparator rejects the comparison | ||
| // value types, type validation rejects at construction | ||
| let err = eval_in_list_from_array( | ||
| wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))), | ||
| wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), | ||
| ) | ||
| .unwrap_err() | ||
| .to_string(); | ||
| assert!( | ||
| err.contains("Can't compare arrays of different types"), | ||
| "{err}" | ||
| assert!(err.contains("The data type inlist should be same"), "{err}"); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_try_new_from_array_dict_haystack_negated() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); | ||
| let needle = Int32Array::from(vec![1, 2, 3, 4]); | ||
| let batch = | ||
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; | ||
|
|
||
| let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]); | ||
|
|
||
| let col_a = col("a", &schema)?; | ||
| let expr = InListExpr::try_new_from_array(col_a, haystack, true, &schema)?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| let result = as_boolean_array(&result); | ||
| assert_eq!( | ||
| result, | ||
| &BooleanArray::from(vec![Some(false), None, Some(false), None]) | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_try_new_from_array_dict_haystack_utf8() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); | ||
| let needle = StringArray::from(vec!["a", "b", "c"]); | ||
| let batch = | ||
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; | ||
|
|
||
| let dict_builder = StringDictionaryBuilder::<Int8Type>::new(); | ||
| let mut builder = dict_builder; | ||
| builder.append_value("a"); | ||
| builder.append_value("c"); | ||
| let haystack: ArrayRef = Arc::new(builder.finish()); | ||
|
|
||
| let col_a = col("a", &schema)?; | ||
| let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| let result = as_boolean_array(&result); | ||
| assert_eq!( | ||
| result, | ||
| &BooleanArray::from(vec![Some(true), Some(false), Some(true)]) | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_try_new_from_array_dict_needle_and_plain_haystack() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new( | ||
| "a", | ||
| DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), | ||
| false, | ||
| )]); | ||
|
|
||
| let needle = make_int32_dict_array(vec![Some(1), Some(2), Some(3), Some(4)]); | ||
| let batch = | ||
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::clone(&needle)])?; | ||
|
|
||
| let haystack: ArrayRef = Arc::new(Int32Array::from(vec![1, 3])); | ||
| let col_a = col("a", &schema)?; | ||
| let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| let result = as_boolean_array(&result); | ||
| assert_eq!( | ||
| result, | ||
| &BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]) | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_try_new_from_array_dict_haystack_float64() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); | ||
| let needle = Float64Array::from(vec![1.0, 2.0, 3.0]); | ||
| let batch = | ||
| RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; | ||
|
|
||
| let haystack = make_f64_dict_array(vec![Some(1.0), Some(3.0)]); | ||
|
|
||
| let col_a = col("a", &schema)?; | ||
| let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; | ||
| let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; | ||
| let result = as_boolean_array(&result); | ||
| assert_eq!( | ||
| result, | ||
| &BooleanArray::from(vec![Some(true), Some(false), Some(true)]) | ||
| ); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_try_new_from_array_type_mismatch_rejects() -> Result<()> { | ||
| let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); | ||
| let col_a = col("a", &schema)?; | ||
| let haystack: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0])); | ||
|
|
||
| let result = InListExpr::try_new_from_array(col_a, haystack, false, &schema); | ||
| assert!(result.is_err()); | ||
| Ok(()) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that
try_new_from_arraynow does its own logical type validation, buttry_newstill has a very similar check a few lines below.Would it make sense to extract this into a small helper so both constructors share the same validation path and error message? That might help avoid subtle drift in the future, especially around dictionary or logical equality handling.