Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3911ed0
feat: add support for array_position expression
andygrove Jan 15, 2026
8cb27ec
update docs
andygrove Jan 15, 2026
36c3320
Merge remote-tracking branch 'origin/main' into feature/array-position
andygrove Jan 30, 2026
4864e9a
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Feb 10, 2026
258442e
Migrate array_position tests to SQL file-based approach
andygrove Feb 10, 2026
3f7d69d
upmerge
andygrove Feb 11, 2026
554dcf8
Merge branch 'main' of https://github.com/apache/datafusion-comet int…
andygrove Feb 18, 2026
b716584
revert whitespace change in CometArrayExpressionSuite.scala
andygrove Feb 18, 2026
30c0de7
fix merge artifacts and expand array_position test coverage
andygrove Feb 18, 2026
cd3a577
rustfmt
andygrove Feb 18, 2026
d9d424f
Merge branch 'main' into feature/array-position
andygrove Feb 26, 2026
c66024e
upmerge
andygrove Mar 7, 2026
6172a0c
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Mar 12, 2026
e4620e9
fix: remove stray merge conflict marker
andygrove Mar 12, 2026
6cf8f6a
feat: optimize array_position with typed array comparison and address…
andygrove Mar 16, 2026
fad07b9
fix: explain NaN incompatibility in ArrayPosition getSupportLevel
andygrove Mar 16, 2026
24318d3
fix: handle NaN equality in array_position to match Spark semantics
andygrove Mar 16, 2026
b7ddad1
perf: use flat values buffer and offsets for array_position
andygrove Mar 16, 2026
79b348f
perf: use native_datafusion scan in array_position benchmark
andygrove Mar 16, 2026
41b5495
format
andygrove Mar 18, 2026
6a4ce49
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Apr 13, 2026
736ca63
chore: apply cargo fmt
andygrove Apr 13, 2026
0f2e41e
Merge branch 'main' into feature/array-position
andygrove Apr 14, 2026
8986da2
fix: address review feedback for array_position
andygrove Apr 14, 2026
bbfe1dd
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Apr 18, 2026
235d053
style: cargo fmt
andygrove Apr 18, 2026
7558c66
test: add timestamp_ntz coverage for array_position
andygrove Apr 21, 2026
60bceb5
test: add column-based nested array coverage for array_position
andygrove Apr 21, 2026
d49130c
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Apr 21, 2026
ca48fc7
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Apr 23, 2026
575e04e
Merge remote-tracking branch 'apache/main' into feature/array-position
andygrove Apr 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
- [x] array_join
- [x] array_max
- [ ] array_min
- [ ] array_position
- [x] array_position
- [x] array_remove
- [x] array_repeat
- [x] array_union
Expand Down
335 changes: 335 additions & 0 deletions native/spark-expr/src/array_funcs/array_position.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait,
};
use arrow::buffer::{NullBuffer, ScalarBuffer};
use arrow::datatypes::{
ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type,
Int32Type, Int64Type, Int8Type, TimestampMicrosecondType,
};
use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use num::Float;
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible `array_position`: returns the 1-based position of an
/// element within an array, or 0 when the element is absent. (Spark returns
/// 0 here, whereas DataFusion's built-in variant returns null.)
fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    if args.len() != 2 {
        return exec_err!("array_position function takes exactly two arguments");
    }

    // A scalar result is only produced when every input is a scalar.
    let all_scalar = args
        .iter()
        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));

    let arrays = ColumnarValue::values_to_arrays(args)?;
    let positions = array_position_inner(&arrays)?;

    if all_scalar {
        // Collapse the single-row result array back into a scalar value.
        let scalar = ScalarValue::try_from_array(&positions, 0)?;
        Ok(ColumnarValue::Scalar(scalar))
    } else {
        Ok(ColumnarValue::Array(positions))
    }
}

/// Dispatches on the list flavor (`List` vs `LargeList`) of the first
/// argument; all other types are rejected with an execution error.
fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    let (list, needle) = (&args[0], &args[1]);

    match list.data_type() {
        DataType::List(_) => generic_array_position::<i32>(list, needle),
        DataType::LargeList(_) => generic_array_position::<i64>(list, needle),
        other => exec_err!("array_position does not support type '{other:?}'"),
    }
}

/// Searches for an element in a list array using the flat values buffer and
/// offsets directly, avoiding per-row subarray allocation.
///
/// Dispatches to a typed fast path based on the element data type; types
/// without a fast path (nested arrays, structs, etc.) fall back to
/// `ScalarValue`-based comparison.
fn generic_array_position<O: OffsetSizeTrait>(
    array: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
    let list_array = array
        .as_any()
        .downcast_ref::<GenericListArray<O>>()
        .ok_or_else(|| DataFusionError::Internal("expected list array".into()))?;

    let values = list_array.values();
    let offsets = list_array.offsets();

    // Match directly on the borrowed data type; cloning it (as the previous
    // version did) is unnecessary since no arm needs an owned `DataType`.
    match values.data_type() {
        DataType::Boolean => position_boolean::<O>(list_array, offsets, values, element),
        DataType::Int8 => position_primitive::<O, Int8Type>(list_array, offsets, values, element),
        DataType::Int16 => position_primitive::<O, Int16Type>(list_array, offsets, values, element),
        DataType::Int32 => position_primitive::<O, Int32Type>(list_array, offsets, values, element),
        DataType::Int64 => position_primitive::<O, Int64Type>(list_array, offsets, values, element),
        DataType::Float32 => position_float::<O, Float32Type>(list_array, offsets, values, element),
        DataType::Float64 => position_float::<O, Float64Type>(list_array, offsets, values, element),
        DataType::Decimal128(_, _) => {
            position_primitive::<O, Decimal128Type>(list_array, offsets, values, element)
        }
        DataType::Date32 => {
            position_primitive::<O, Date32Type>(list_array, offsets, values, element)
        }
        DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
            position_primitive::<O, TimestampMicrosecondType>(list_array, offsets, values, element)
        }
        DataType::Utf8 => position_string::<O, i32>(list_array, offsets, values, element),
        DataType::LargeUtf8 => position_string::<O, i64>(list_array, offsets, values, element),
        // Fallback to ScalarValue for complex types (nested arrays, etc.)
        _ => position_fallback::<O>(list_array, offsets, values, element),
    }
}

/// Computes the combined null buffer from the list-array nulls and the
/// search-element nulls: a result row is null when either input is null.
fn combined_nulls(
    list_array_nulls: Option<&NullBuffer>,
    element_nulls: Option<&NullBuffer>,
) -> Option<NullBuffer> {
    // `NullBuffer::union` already implements the full four-way Option logic
    // the previous hand-written match duplicated: Some/Some -> bitwise AND of
    // validity, one-sided -> clone, None/None -> None.
    NullBuffer::union(list_array_nulls, element_nulls)
}

/// Fast path for primitive element types: one downcast up front, then direct
/// scans over the flat values buffer using the list offsets.
fn position_primitive<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
    list_array: &GenericListArray<O>,
    offsets: &arrow::buffer::OffsetBuffer<O>,
    values: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError>
where
    T::Native: PartialEq,
{
    let flat = values.as_primitive::<T>();
    let needles = element.as_primitive::<T>();
    let nulls = combined_nulls(list_array.nulls(), element.nulls());
    let mut positions = vec![0i64; list_array.len()];

    for (row, window) in offsets.windows(2).enumerate() {
        if nulls.as_ref().is_some_and(|n| n.is_null(row)) {
            continue;
        }
        let (start, end) = (window[0].as_usize(), window[1].as_usize());
        let needle = needles.value(row);
        // 1-based position of the first non-null match; 0 (preset) otherwise.
        if let Some(found) =
            (start..end).find(|&i| !flat.is_null(i) && flat.value(i) == needle)
        {
            positions[row] = (found - start + 1) as i64;
        }
    }

    Ok(Arc::new(Int64Array::new(ScalarBuffer::from(positions), nulls)))
}

/// Float path: identical to the primitive path except that NaN is treated as
/// equal to NaN, matching Spark's `ordering.equiv()` semantics (IEEE 754
/// would make NaN unequal to everything, including itself).
fn position_float<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
    list_array: &GenericListArray<O>,
    offsets: &arrow::buffer::OffsetBuffer<O>,
    values: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError>
where
    T::Native: PartialEq + num::Float,
{
    let flat = values.as_primitive::<T>();
    let needles = element.as_primitive::<T>();
    let nulls = combined_nulls(list_array.nulls(), element.nulls());
    let mut positions = vec![0i64; list_array.len()];

    for (row, window) in offsets.windows(2).enumerate() {
        if nulls.as_ref().is_some_and(|n| n.is_null(row)) {
            continue;
        }
        let (start, end) = (window[0].as_usize(), window[1].as_usize());
        let needle = needles.value(row);
        let needle_is_nan = needle.is_nan();
        let is_match = |i: usize| {
            if flat.is_null(i) {
                return false;
            }
            let v = flat.value(i);
            (needle_is_nan && v.is_nan()) || v == needle
        };
        if let Some(found) = (start..end).find(|&i| is_match(i)) {
            positions[row] = (found - start + 1) as i64;
        }
    }

    Ok(Arc::new(Int64Array::new(ScalarBuffer::from(positions), nulls)))
}

/// Specialization for boolean-element lists.
fn position_boolean<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    offsets: &arrow::buffer::OffsetBuffer<O>,
    values: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
    let flat = values
        .as_any()
        .downcast_ref::<BooleanArray>()
        .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?;
    let needles = element
        .as_any()
        .downcast_ref::<BooleanArray>()
        .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?;
    let nulls = combined_nulls(list_array.nulls(), element.nulls());
    let mut positions = vec![0i64; list_array.len()];

    for (row, window) in offsets.windows(2).enumerate() {
        if nulls.as_ref().is_some_and(|n| n.is_null(row)) {
            continue;
        }
        let (start, end) = (window[0].as_usize(), window[1].as_usize());
        let needle = needles.value(row);
        // 1-based position of the first non-null match; 0 (preset) otherwise.
        if let Some(found) =
            (start..end).find(|&i| !flat.is_null(i) && flat.value(i) == needle)
        {
            positions[row] = (found - start + 1) as i64;
        }
    }

    Ok(Arc::new(Int64Array::new(ScalarBuffer::from(positions), nulls)))
}

/// String path: one downcast up front, then per-row scans that compare
/// slices of the flat string buffer directly.
fn position_string<O: OffsetSizeTrait, S: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    offsets: &arrow::buffer::OffsetBuffer<O>,
    values: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
    let flat = values.as_string::<S>();
    let needles = element.as_string::<S>();
    let nulls = combined_nulls(list_array.nulls(), element.nulls());
    let mut positions = vec![0i64; list_array.len()];

    for (row, window) in offsets.windows(2).enumerate() {
        if nulls.as_ref().is_some_and(|n| n.is_null(row)) {
            continue;
        }
        let (start, end) = (window[0].as_usize(), window[1].as_usize());
        let needle = needles.value(row);
        // 1-based position of the first non-null match; 0 (preset) otherwise.
        if let Some(found) =
            (start..end).find(|&i| !flat.is_null(i) && flat.value(i) == needle)
        {
            positions[row] = (found - start + 1) as i64;
        }
    }

    Ok(Arc::new(Int64Array::new(ScalarBuffer::from(positions), nulls)))
}

/// Fallback for complex element types (nested arrays, structs, etc.) that
/// compares elements via `ScalarValue` equality. Slower than the typed
/// paths, but correct for any Arrow type.
fn position_fallback<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    offsets: &arrow::buffer::OffsetBuffer<O>,
    values: &ArrayRef,
    element: &ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
    let nulls = combined_nulls(list_array.nulls(), element.nulls());
    let mut positions = vec![0i64; list_array.len()];

    for (row, window) in offsets.windows(2).enumerate() {
        if nulls.as_ref().is_some_and(|n| n.is_null(row)) {
            continue;
        }
        let (start, end) = (window[0].as_usize(), window[1].as_usize());
        // Materialize the needle once per row; candidate elements are
        // materialized lazily as the scan advances.
        let needle = ScalarValue::try_from_array(element, row)?;
        for i in start..end {
            if values.is_null(i) {
                continue;
            }
            if ScalarValue::try_from_array(values, i)? == needle {
                positions[row] = (i - start + 1) as i64;
                break;
            }
        }
    }

    Ok(Arc::new(Int64Array::new(ScalarBuffer::from(positions), nulls)))
}

/// DataFusion scalar UDF exposing Spark-compatible `array_position`.
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct SparkArrayPositionFunc {
    // Permissive two-argument signature; types are checked at execution time.
    signature: Signature,
}

impl Default for SparkArrayPositionFunc {
    // Delegates to `new()` so the UDF can be registered via `default()`.
    fn default() -> Self {
        Self::new()
    }
}

impl SparkArrayPositionFunc {
    /// Creates the UDF with an `Any(2)` signature: any two arguments are
    /// accepted here, and type validation happens during execution.
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkArrayPositionFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Name under which the function is registered and resolved from plans.
    fn name(&self) -> &str {
        "spark_array_position"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // The position is always a 64-bit integer regardless of argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
        Ok(DataType::Int64)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
        spark_array_position(&args.args)
    }
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/array_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

mod array_compact;
mod array_insert;
mod array_position;
mod arrays_overlap;
mod arrays_zip;
mod get_array_struct_fields;
Expand All @@ -25,6 +26,7 @@ mod size;

pub use array_compact::SparkArrayCompact;
pub use array_insert::ArrayInsert;
pub use array_position::SparkArrayPositionFunc;
pub use arrays_overlap::SparkArraysOverlap;
pub use arrays_zip::SparkArraysZipFunc;
pub use get_array_struct_fields::GetArrayStructFields;
Expand Down
6 changes: 4 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArraysOverlap, SparkContains,
SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate, SparkSizeFunc,
spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayPositionFunc, SparkArraysOverlap,
SparkContains, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate,
SparkSizeFunc,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -201,6 +202,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
vec![
Arc::new(ScalarUDF::new_from_impl(SparkArrayCompact::default())),
Arc::new(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default())),
Arc::new(ScalarUDF::new_from_impl(SparkArraysOverlap::default())),
Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[ArrayJoin] -> CometArrayJoin,
classOf[ArrayMax] -> CometArrayMax,
classOf[ArrayMin] -> CometArrayMin,
classOf[ArrayPosition] -> CometArrayPosition,
classOf[ArrayRemove] -> CometArrayRemove,
classOf[ArrayRepeat] -> CometArrayRepeat,
classOf[SortArray] -> CometSortArray,
Expand Down
Loading
Loading