Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 93 additions & 15 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ use arrow::array::{
FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray,
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray,
LargeStringArray, ListArray, MapArray, MutableArrayData, PrimitiveArray, Scalar,
StringArray, StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray,
Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
UInt16Array, UInt32Array, UInt64Array, UInt8Array, UnionArray,
LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait,
PrimitiveArray, Scalar, StringArray, StringViewArray, StructArray,
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array, UnionArray,
};
use arrow::buffer::ScalarBuffer;
use arrow::compute::kernels::cast::{cast_with_options, CastOptions};
Expand Down Expand Up @@ -3305,17 +3306,30 @@ impl ScalarValue {
/// assert_eq!(scalar_vec, expected);
/// ```
pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result<Vec<Vec<Self>>> {
let mut scalars = Vec::with_capacity(array.len());

for index in 0..array.len() {
let nested_array = array.as_list::<i32>().value(index);
let scalar_values = (0..nested_array.len())
.map(|i| ScalarValue::try_from_array(&nested_array, i))
.collect::<Result<Vec<_>>>()?;
scalars.push(scalar_values);
fn generic_collect<OffsetSize: OffsetSizeTrait>(
array: &dyn Array,
) -> Result<Vec<Vec<ScalarValue>>> {
array
.as_list::<OffsetSize>()
.iter()
.map(|nested_array| match nested_array {
Some(nested_array) => (0..nested_array.len())
.map(|i| ScalarValue::try_from_array(&nested_array, i))
.collect::<Result<Vec<_>>>(),
// TODO: what can we put for null?
// https://github.com/apache/datafusion/issues/17749
None => Ok(vec![]),
})
.collect()
}

Ok(scalars)
match array.data_type() {
DataType::List(_) => generic_collect::<i32>(array),
DataType::LargeList(_) => generic_collect::<i64>(array),
_ => _internal_err!(
"ScalarValue::convert_array_to_scalar_vec input must be a List/LargeList type"
),
}
}

#[deprecated(
Expand Down Expand Up @@ -4947,6 +4961,8 @@ impl ScalarType<i32> for Date32Type {

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;
use crate::cast::{as_list_array, as_map_array, as_struct_array};
use crate::test_util::batches_to_string;
Expand All @@ -4955,7 +4971,7 @@ mod tests {
NullArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveBuilder, RecordBatch,
StringBuilder, StringDictionaryBuilder, StructBuilder, UnionBuilder,
};
use arrow::buffer::{Buffer, OffsetBuffer};
use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer};
use arrow::compute::{is_null, kernels};
use arrow::datatypes::{
ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION,
Expand Down Expand Up @@ -8996,4 +9012,66 @@ mod tests {
_ => panic!("Expected TimestampMillisecond with timezone"),
}
}

#[test]
// Verifies `ScalarValue::convert_array_to_scalar_vec` for `ListArray`,
// `LargeListArray`, and a `ListArray` whose null slot still spans values in
// the child array. Note: a null top-level list entry currently converts to an
// empty `Vec` (see https://github.com/apache/datafusion/issues/17749), which
// is why each `None` slot below is asserted as `vec![]`.
fn test_convert_array_to_scalar_vec() {
// Regular ListArray
// Three slots: [1, 2], NULL, [3, NULL, 4]; inner nulls become Int64(None).
let list = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(1), Some(2)]),
None,
Some(vec![Some(3), None, Some(4)]),
]);
let converted = ScalarValue::convert_array_to_scalar_vec(&list).unwrap();
assert_eq!(
converted,
vec![
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
vec![],
vec![
ScalarValue::Int64(Some(3)),
ScalarValue::Int64(None),
ScalarValue::Int64(Some(4))
],
]
);

// Regular LargeListArray
// Same shape as above but with i64 offsets, exercising the LargeList branch.
let large_list = LargeListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(1), Some(2)]),
None,
Some(vec![Some(3), None, Some(4)]),
]);
let converted = ScalarValue::convert_array_to_scalar_vec(&large_list).unwrap();
assert_eq!(
converted,
vec![
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
vec![],
vec![
ScalarValue::Int64(Some(3)),
ScalarValue::Int64(None),
ScalarValue::Int64(Some(4))
],
]
);

// Funky (null slot has non-zero list offsets)
// Offsets + Values looks like this: [[1, 2], [3, 4], [5]]
// But with NullBuffer it's like this: [[1, 2], NULL, [5]]
// The conversion must honor the validity buffer and skip the null slot's
// underlying values (3, 4) rather than emit them.
let funky = ListArray::new(
Field::new_list_field(DataType::Int64, true).into(),
// Offsets [0, 2, 4, 5]: slot 0 -> values[0..2], slot 1 -> values[2..4],
// slot 2 -> values[4..5]; trailing value 6 is unused by any slot.
OffsetBuffer::new(vec![0, 2, 4, 5].into()),
Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])),
// Slot 1 is marked null even though its offset range is non-empty.
Some(NullBuffer::from(vec![true, false, true])),
);
let converted = ScalarValue::convert_array_to_scalar_vec(&funky).unwrap();
assert_eq!(
converted,
vec![
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
vec![],
vec![ScalarValue::Int64(Some(5))],
]
);
}
}
70 changes: 36 additions & 34 deletions datafusion/functions-nested/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ use datafusion_common::cast::{as_fixed_size_list_array, as_generic_list_array};
use datafusion_common::utils::string_utils::string_array_to_vec;
use datafusion_common::utils::take_function_args;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, ScalarFunction};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
in_list, ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::datum::compare_with_eq;
Expand Down Expand Up @@ -131,40 +131,42 @@ impl ScalarUDFImpl for ArrayHas {

// if the haystack is a constant list, we can use an inlist expression which is more
// efficient because the haystack is not varying per-row
if let Expr::Literal(ScalarValue::List(array), _) = haystack {
// TODO: support LargeList
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

// (not supported by `convert_array_to_scalar_vec`)
// (FixedSizeList not supported either, but seems to have worked fine when attempting to
// build a reproducer)

assert_eq!(array.len(), 1); // guarantee of ScalarValue
if let Ok(scalar_values) =
ScalarValue::convert_array_to_scalar_vec(array.as_ref())
{
assert_eq!(scalar_values.len(), 1);
let list = scalar_values
.into_iter()
.flatten()
.map(|v| Expr::Literal(v, None))
.collect();

return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
expr: Box::new(std::mem::take(needle)),
list,
negated: false,
})));
match haystack {
Expr::Literal(
// FixedSizeList gets coerced to List
scalar @ ScalarValue::List(_) | scalar @ ScalarValue::LargeList(_),
_,
) => {
let array = scalar.to_array().unwrap(); // guarantee of ScalarValue
if let Ok(scalar_values) =
ScalarValue::convert_array_to_scalar_vec(&array)
{
assert_eq!(scalar_values.len(), 1);
let list = scalar_values
.into_iter()
.flatten()
.map(|v| Expr::Literal(v, None))
.collect();

return Ok(ExprSimplifyResult::Simplified(in_list(
std::mem::take(needle),
list,
false,
)));
}
}
} else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack {
// make_array has a static set of arguments, so we can pull the arguments out from it
if func == &make_array_udf() {
return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
expr: Box::new(std::mem::take(needle)),
list: std::mem::take(args),
negated: false,
})));
Expr::ScalarFunction(ScalarFunction { func, args })
if func == &make_array_udf() =>
{
// make_array has a static set of arguments, so we can pull the arguments out from it
return Ok(ExprSimplifyResult::Simplified(in_list(
std::mem::take(needle),
std::mem::take(args),
false,
)));
}
}

_ => {}
};
Ok(ExprSimplifyResult::Original(args))
}

Expand Down
16 changes: 7 additions & 9 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6470,14 +6470,12 @@ physical_plan
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

# FIXME: due to rewrite below not working, this is _extremely_ slow to evaluate
# query I
# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
# select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
# ----
# 1
query I
with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
----
1

# FIXME: array_has with large list haystack not currently rewritten to InList
query TT
explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i))
select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle);
Expand All @@ -6488,7 +6486,7 @@ logical_plan
03)----SubqueryAlias: test
04)------SubqueryAlias: t
05)--------Projection:
06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)))
06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")])
07)------------TableScan: generate_series() projection=[value]
physical_plan
01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
Expand All @@ -6497,7 +6495,7 @@ physical_plan
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
05)--------ProjectionExec: expr=[]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8View)), 1, 32))
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a pretty terrible plan display (not made worse by this PR, of course)

08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]

Expand Down