diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 1247aba3258d..8c079056e21d 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -63,11 +63,12 @@ use arrow::array::{ FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, - LargeStringArray, ListArray, MapArray, MutableArrayData, PrimitiveArray, Scalar, - StringArray, StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, - Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, UnionArray, + LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, + PrimitiveArray, Scalar, StringArray, StringViewArray, StructArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, UnionArray, }; use arrow::buffer::ScalarBuffer; use arrow::compute::kernels::cast::{cast_with_options, CastOptions}; @@ -3305,17 +3306,30 @@ impl ScalarValue { /// assert_eq!(scalar_vec, expected); /// ``` pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { - let mut scalars = Vec::with_capacity(array.len()); - - for index in 0..array.len() { - let nested_array = array.as_list::().value(index); - let scalar_values = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - scalars.push(scalar_values); + fn generic_collect( + array: &dyn Array, + ) -> Result>> { + array + .as_list::() + .iter() + .map(|nested_array| match nested_array { + Some(nested_array) => (0..nested_array.len()) + .map(|i| ScalarValue::try_from_array(&nested_array, i)) + .collect::>>(), + // TODO: what can we put for null? + // https://github.com/apache/datafusion/issues/17749 + None => Ok(vec![]), + }) + .collect() } - Ok(scalars) + match array.data_type() { + DataType::List(_) => generic_collect::(array), + DataType::LargeList(_) => generic_collect::(array), + _ => _internal_err!( + "ScalarValue::convert_array_to_scalar_vec input must be a List/LargeList type" + ), + } } #[deprecated( @@ -4947,6 +4961,8 @@ impl ScalarType for Date32Type { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::cast::{as_list_array, as_map_array, as_struct_array}; use crate::test_util::batches_to_string; @@ -4955,7 +4971,7 @@ mod tests { NullArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveBuilder, RecordBatch, StringBuilder, StringDictionaryBuilder, StructBuilder, UnionBuilder, }; - use arrow::buffer::{Buffer, OffsetBuffer}; + use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; use arrow::datatypes::{ ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION, @@ -8996,4 +9012,66 @@ mod tests { _ => panic!("Expected TimestampMillisecond with timezone"), } } + + #[test] + fn test_convert_array_to_scalar_vec() { + // Regular ListArray + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&list).unwrap(); + assert_eq!( + converted, + vec![ + vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))], + vec![], + vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ], + ] + ); + + // Regular LargeListArray + let large_list = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&large_list).unwrap(); + assert_eq!( + converted, + vec![ + vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))], + vec![], + vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ], + ] + ); + + // Funky (null slot has non-zero list offsets) + // Offsets + Values looks like this: [[1, 2], [3, 4], [5]] + // But with NullBuffer it's like this: [[1, 2], NULL, [5]] + let funky = ListArray::new( + Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 4, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, false, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&funky).unwrap(); + assert_eq!( + converted, + vec![ + vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))], + vec![], + vec![ScalarValue::Int64(Some(5))], + ] + ); + } } diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index a6375808b102..f77cc5dd7b39 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -25,10 +25,10 @@ use datafusion_common::cast::{as_fixed_size_list_array, as_generic_list_array}; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::utils::take_function_args; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, ScalarFunction}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + in_list, ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use datafusion_physical_expr_common::datum::compare_with_eq; @@ -131,40 +131,42 @@ impl ScalarUDFImpl for ArrayHas { // if the haystack is a constant list, we can use an inlist expression which is more // efficient because the haystack is not varying per-row - if let Expr::Literal(ScalarValue::List(array), _) = haystack { - // TODO: support LargeList - // (not supported by `convert_array_to_scalar_vec`) - // (FixedSizeList not supported either, but seems to have worked fine when attempting to - // build a reproducer) - - assert_eq!(array.len(), 1); // guarantee of ScalarValue - if let Ok(scalar_values) = - ScalarValue::convert_array_to_scalar_vec(array.as_ref()) - { - assert_eq!(scalar_values.len(), 1); - let list = scalar_values - .into_iter() - .flatten() - .map(|v| Expr::Literal(v, None)) - .collect(); - - return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { - expr: Box::new(std::mem::take(needle)), - list, - negated: false, - }))); + match haystack { + Expr::Literal( + // FixedSizeList gets coerced to List + scalar @ ScalarValue::List(_) | scalar @ ScalarValue::LargeList(_), + _, + ) => { + let array = scalar.to_array().unwrap(); // guarantee of ScalarValue + if let Ok(scalar_values) = + ScalarValue::convert_array_to_scalar_vec(&array) + { + assert_eq!(scalar_values.len(), 1); + let list = scalar_values + .into_iter() + .flatten() + .map(|v| Expr::Literal(v, None)) + .collect(); + + return Ok(ExprSimplifyResult::Simplified(in_list( + std::mem::take(needle), + list, + false, + ))); + } } - } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack { - // make_array has a static set of arguments, so we can pull the arguments out from it - if func == &make_array_udf() { - return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { - expr: Box::new(std::mem::take(needle)), - list: std::mem::take(args), - negated: false, - }))); + Expr::ScalarFunction(ScalarFunction { func, args }) + if func == &make_array_udf() => + { + // make_array has a static set of arguments, so we can pull the arguments out from it + return Ok(ExprSimplifyResult::Simplified(in_list( + std::mem::take(needle), + std::mem::take(args), + false, + ))); } - } - + _ => {} + }; Ok(ExprSimplifyResult::Original(args)) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index a135f1d184c4..764488e00f07 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6470,14 +6470,12 @@ physical_plan 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] -# FIXME: due to rewrite below not working, this is _extremely_ slow to evaluate -# query I -# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) -# select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); -# ---- -# 1 +query I +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); +---- +1 -# FIXME: array_has with large list haystack not currently rewritten to InList query TT explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); @@ -6488,7 +6486,7 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32))) +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) 07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] @@ -6497,7 +6495,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8View)), 1, 32)) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]