diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index e5e601f30ae8..ee1d92ce5db8 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -96,3 +96,7 @@ name = "array_repeat" [[bench]] harness = false name = "array_set_ops" + +[[bench]] +harness = false +name = "array_position" diff --git a/datafusion/functions-nested/benches/array_position.rs b/datafusion/functions-nested/benches/array_position.rs new file mode 100644 index 000000000000..08367648449d --- /dev/null +++ b/datafusion/functions-nested/benches/array_position.rs @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::position::ArrayPosition; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; +const SENTINEL_NEEDLE: i64 = -1; + +fn criterion_benchmark(c: &mut Criterion) { + for size in [10, 100, 500] { + bench_array_position(c, size); + } +} + +fn bench_array_position(c: &mut Criterion, array_size: usize) { + let mut group = c.benchmark_group("array_position_i64"); + let haystack_found_once = create_haystack_with_sentinel( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + 0, + ); + let haystack_found_many = create_haystack_with_sentinels( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + ); + let haystack_not_found = + create_haystack_without_sentinel(NUM_ROWS, array_size, NULL_DENSITY); + let num_rows = haystack_not_found.len(); + let arg_fields: Vec> = vec![ + Field::new("haystack", haystack_not_found.data_type().clone(), false).into(), + Field::new("needle", DataType::Int64, false).into(), + ]; + let return_field: Arc = Field::new("result", DataType::UInt64, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + let needle = ScalarValue::Int64(Some(SENTINEL_NEEDLE)); + + // Benchmark: one match per row. 
+ let args_found_once = vec![ + ColumnarValue::Array(haystack_found_once.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_once", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_once.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + // Benchmark: many matches per row. + let args_found_many = vec![ + ColumnarValue::Array(haystack_found_many.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_many", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_many.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + // Benchmark: needle is not found in any row. 
+ let args_not_found = vec![ + ColumnarValue::Array(haystack_not_found.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("not_found", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + group.finish(); +} + +fn create_haystack_without_sentinel( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + create_haystack_from_fn(num_rows, array_size, |_, _, rng| { + random_haystack_value(rng, array_size, null_density) + }) +} + +fn create_haystack_with_sentinel( + num_rows: usize, + array_size: usize, + null_density: f64, + sentinel: i64, + sentinel_index: usize, +) -> ArrayRef { + assert!(sentinel_index < array_size); + + create_haystack_from_fn(num_rows, array_size, |_, col, rng| { + if col == sentinel_index { + Some(sentinel) + } else { + random_haystack_value(rng, array_size, null_density) + } + }) +} + +fn create_haystack_with_sentinels( + num_rows: usize, + array_size: usize, + null_density: f64, + sentinel: i64, +) -> ArrayRef { + create_haystack_from_fn(num_rows, array_size, |_, col, rng| { + // Place the sentinel in half the positions to create many matches per row. 
+ if col % 2 == 0 { + Some(sentinel) + } else { + random_haystack_value(rng, array_size, null_density) + } + }) +} + +fn create_haystack_from_fn( + num_rows: usize, + array_size: usize, + mut value_at: F, +) -> ArrayRef +where + F: FnMut(usize, usize, &mut StdRng) -> Option, +{ + let mut rng = StdRng::seed_from_u64(SEED); + let mut values = Vec::with_capacity(num_rows * array_size); + for row in 0..num_rows { + for col in 0..array_size { + values.push(value_at(row, col, &mut rng)); + } + } + let values = values.into_iter().collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn random_haystack_value( + rng: &mut StdRng, + array_size: usize, + null_density: f64, +) -> Option { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index fc3a295963ce..ba16d08538c6 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -17,11 +17,13 @@ //! [`ScalarUDFImpl`] definitions for array_position and array_positions functions. 
+use arrow::array::Scalar; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List, UInt64}, Field, }; +use datafusion_common::ScalarValue; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -37,9 +39,7 @@ use arrow::array::{ use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{ - Result, assert_or_internal_err, exec_err, utils::take_function_args, -}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; @@ -54,7 +54,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.", + description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL.", syntax_example = "array_position(array, element)\narray_position(array, element, index)", sql_example = r#"```sql > select array_position([1, 2, 2, 3, 1, 4], 2); @@ -74,10 +74,7 @@ make_udf_expr_and_func!( name = "array", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ), - argument( - name = "element", - description = "Element to search for position in the array." - ), + argument(name = "element", description = "Element to search for in the array."), argument( name = "index", description = "Index at which to start searching (1-indexed)." @@ -129,7 +126,54 @@ impl ScalarUDFImpl for ArrayPosition { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_position_inner)(&args.args) + let [first_arg, second_arg, third_arg @ ..] 
= args.args.as_slice() else { + return exec_err!("array_position expects two or three arguments"); + }; + + match second_arg { + ColumnarValue::Scalar(scalar_element) => { + // Nested element types (List, Struct) can't use the fast path + // (because Arrow's `not_distinct` does not support them). + if scalar_element.data_type().is_nested() { + return make_scalar_function(array_position_inner)(&args.args); + } + + // Determine batch length from whichever argument is columnar; + // if all inputs are scalar, batch length is 1. + let (num_rows, all_inputs_scalar) = match (first_arg, third_arg.first()) { + (ColumnarValue::Array(a), _) => (a.len(), false), + (_, Some(ColumnarValue::Array(a))) => (a.len(), false), + _ => (1, true), + }; + + let element_arr = scalar_element.to_array_of_size(1)?; + let haystack = first_arg.to_array(num_rows)?; + let arr_from = resolve_start_from(third_arg.first(), num_rows)?; + + let result = match haystack.data_type() { + List(_) => { + let list = as_generic_list_array::(&haystack)?; + array_position_scalar::(list, &element_arr, &arr_from) + } + LargeList(_) => { + let list = as_generic_list_array::(&haystack)?; + array_position_scalar::(list, &element_arr, &arr_from) + } + t => exec_err!("array_position does not support type '{t}'."), + }?; + + if all_inputs_scalar { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } else { + Ok(ColumnarValue::Array(result)) + } + } + ColumnarValue::Array(_) => { + make_scalar_function(array_position_inner)(&args.args) + } + } } fn aliases(&self) -> &[String] { @@ -152,6 +196,99 @@ fn array_position_inner(args: &[ArrayRef]) -> Result { } } +/// Resolves the optional `start_from` argument into a `Vec<i64>` of +/// 0-indexed starting positions. 
+fn resolve_start_from( + third_arg: Option<&ColumnarValue>, + num_rows: usize, +) -> Result> { + match third_arg { + None => Ok(vec![0i64; num_rows]), + Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => { + Ok(vec![v - 1; num_rows]) + } + Some(ColumnarValue::Scalar(s)) => { + exec_err!("array_position expected Int64 for start_from, got {s}") + } + Some(ColumnarValue::Array(a)) => { + Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect()) + } + } +} + +/// Fast path for `array_position` when the element is a scalar. +/// +/// Performs a single bulk `not_distinct` comparison of the scalar element +/// against the entire flattened values buffer, then walks the result bitmap +/// using offsets to find per-row first-match positions. +fn array_position_scalar( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: &[i64], // 0-indexed +) -> Result { + crate::utils::check_datatypes( + "array_position", + &[list_array.values(), element_array], + )?; + let element_datum = Scalar::new(Arc::clone(element_array)); + + let offsets = list_array.offsets(); + let validity = list_array.nulls(); + + if list_array.len() == 0 { + return Ok(Arc::new(UInt64Array::new_null(0))); + } + + // `not_distinct` treats NULL=NULL as true, matching the semantics of + // `array_position` + let eq_array = arrow_ord::cmp::not_distinct(list_array.values(), &element_datum)?; + let eq_bits = eq_array.values(); + + let mut result: Vec> = Vec::with_capacity(list_array.len()); + let mut matches = eq_bits.set_indices().peekable(); + + for i in 0..list_array.len() { + let start = offsets[i].as_usize(); + let end = offsets[i + 1].as_usize(); + + if validity.is_some_and(|v| v.is_null(i)) { + // Null row -> null output; advance past matches in range + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + result.push(None); + continue; + } + + let from = arr_from[i]; + let row_len = end - start; + if !(from >= 0 && (from as usize) <= row_len) { + return 
exec_err!("start_from out of bounds: {}", from + 1); + } + let search_start = start + from as usize; + + // Advance past matches before search_start + while matches.peek().is_some_and(|&p| p < search_start) { + matches.next(); + } + + // First match in [search_start, end)? + if matches.peek().is_some_and(|&p| p < end) { + let pos = *matches.peek().unwrap(); + result.push(Some((pos - start + 1) as u64)); + // Advance past remaining matches in this row + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + } else { + result.push(None); + } + } + + debug_assert_eq!(result.len(), list_array.len()); + Ok(Arc::new(UInt64Array::from(result))) +} + fn general_position_dispatch(args: &[ArrayRef]) -> Result { let list_array = as_generic_list_array::(&args[0])?; let element_array = &args[1]; @@ -171,13 +308,11 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result= 0 && (from as usize) <= arr.len()), - "start_from index out of bounds" - ); + if !arr.is_none_or(|arr| from >= 0 && (from as usize) <= arr.len()) { + return exec_err!("start_from out of bounds: {}", from + 1); + } } generic_position::(list_array, element_array, &arr_from) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 66503c957c5a..cf3494394e3e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3880,6 +3880,111 @@ select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), NULL 6 4 NULL 1 NULL +# array_position with NULL element in haystack array (NULL = NULL semantics) +query III +select array_position([1, NULL, 3], arrow_cast(NULL, 'Int64')), array_position([NULL, 2, 3], arrow_cast(NULL, 'Int64')), array_position([1, 2, NULL], arrow_cast(NULL, 'Int64')); +---- +2 1 3 + +query I +select array_position(arrow_cast([1, NULL, 3], 'LargeList(Int64)'), arrow_cast(NULL, 'Int64')); +---- +2 + +# array_position with NULL element in array and start_from 
+query II +select array_position([NULL, 1, NULL, 2], arrow_cast(NULL, 'Int64'), 2), array_position([NULL, 1, NULL, 2], arrow_cast(NULL, 'Int64'), 1); +---- +3 1 + +# array_position with column array and scalar element +query IIII +select array_position(column1, 3), array_position(column1, 10), array_position(column1, 20), array_position(column1, 999) from arrays_values_without_nulls; +---- +3 10 NULL NULL +NULL NULL 10 NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +query II +select array_position(column1, 3), array_position(column1, 20) from large_arrays_values_without_nulls; +---- +3 NULL +NULL 10 +NULL NULL +NULL NULL + +query II +select array_position(column1, 3), array_position(column1, 20) from fixed_size_arrays_values_without_nulls; +---- +3 NULL +NULL 10 +NULL NULL +NULL NULL + +# array_position with column array, scalar element, and scalar start_from +query II +select array_position(column1, 3, 1), array_position(column1, 3, 4) from arrays_values_without_nulls; +---- +3 NULL +NULL NULL +NULL NULL +NULL NULL + +query II +select array_position(column1, 3, 1), array_position(column1, 3, 4) from large_arrays_values_without_nulls; +---- +3 NULL +NULL NULL +NULL NULL +NULL NULL + +# array_position with column array, scalar element, and column start_from +query I +select array_position(column1, 3, column3) from arrays_values_without_nulls; +---- +3 +NULL +NULL +NULL + +# array_position with scalar haystack, scalar element, and column start_from +query I +select array_position([1, 2, 1, 2], 2, column3) from arrays_values_without_nulls; +---- +2 +2 +4 +4 + +# array_position start_from boundary cases +query IIII +select array_position([1, 2, 3], 3, 3), array_position([1, 2, 3], 1, 2), array_position([1, 2, 3], 1, 1), array_position([1, 2, 3], 3, 4); +---- +3 NULL 1 NULL + +query II +select array_position([1, 2, 3], 3, 4), array_position([1], 1, 2); +---- +NULL NULL + +# array_position with empty array in various contexts +query II +select 
array_position(arrow_cast(make_array(), 'List(Int64)'), 1), array_position(arrow_cast(make_array(), 'LargeList(Int64)'), 1); +---- +NULL NULL + +# FixedSizeList with start_from +query II +select array_position(arrow_cast([1, 2, 3, 1, 2], 'FixedSizeList(5, Int64)'), 1, 2), array_position(arrow_cast([1, 2, 3, 1, 2], 'FixedSizeList(5, Int64)'), 2, 4); +---- +4 5 + +query I +select array_position(arrow_cast(['a', 'b', 'c', 'b'], 'FixedSizeList(4, Utf8)'), 'b', 3); +---- +4 + ## array_positions (aliases: `list_positions`) query ? diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ebd2abe2b382..01c682e39da7 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3776,7 +3776,7 @@ array_pop_front(array) ### `array_position` -Returns the position of the first occurrence of the specified element in the array, or NULL if not found. +Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL. ```sql array_position(array, element) @@ -3786,7 +3786,7 @@ array_position(array, element, index) #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for position in the array. +- **element**: Element to search for in the array. - **index**: Index at which to start searching (1-indexed). #### Example