diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 362869f11e..b71ef83aea 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -93,7 +93,7 @@ - [x] array_join - [x] array_max - [ ] array_min -- [ ] array_position +- [x] array_position - [x] array_remove - [x] array_repeat - [x] array_union diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs new file mode 100644 index 0000000000..dbcd8615df --- /dev/null +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -0,0 +1,335 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait, +}; +use arrow::buffer::{NullBuffer, ScalarBuffer}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, +}; +use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use num::Float; +use std::any::Any; +use std::sync::Arc; + +/// Spark array_position() function that returns the 1-based position of an element in an array. +/// Returns 0 if the element is not found (Spark behavior differs from DataFusion which returns null). +fn spark_array_position(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return exec_err!("array_position function takes exactly two arguments"); + } + + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let arrays = ColumnarValue::values_to_arrays(args)?; + + let result = array_position_inner(&arrays)?; + + if is_scalar { + let scalar = ScalarValue::try_from_array(&result, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } else { + Ok(ColumnarValue::Array(result)) + } +} + +fn array_position_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + let element = &args[1]; + + match array.data_type() { + DataType::List(_) => generic_array_position::(array, element), + DataType::LargeList(_) => generic_array_position::(array, element), + other => exec_err!("array_position does not support type '{other:?}'"), + } +} + +/// Searches for an element in a list array using the flat values buffer and offsets directly, +/// avoiding per-row subarray allocation. Dispatches to typed fast paths by element data type. +fn generic_array_position( + array: &ArrayRef, + element: &ArrayRef, +) -> Result { + let list_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| DataFusionError::Internal("expected list array".into()))?; + + let values = list_array.values(); + let offsets = list_array.offsets(); + let elem_type = values.data_type().clone(); + + match &elem_type { + DataType::Boolean => position_boolean::(list_array, offsets, values, element), + DataType::Int8 => position_primitive::(list_array, offsets, values, element), + DataType::Int16 => position_primitive::(list_array, offsets, values, element), + DataType::Int32 => position_primitive::(list_array, offsets, values, element), + DataType::Int64 => position_primitive::(list_array, offsets, values, element), + DataType::Float32 => position_float::(list_array, offsets, values, element), + DataType::Float64 => position_float::(list_array, offsets, values, element), + DataType::Decimal128(_, _) => { + position_primitive::(list_array, offsets, values, element) + } + DataType::Date32 => { + position_primitive::(list_array, offsets, values, element) + } + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { + position_primitive::(list_array, offsets, values, element) + } + DataType::Utf8 => position_string::(list_array, offsets, values, element), + DataType::LargeUtf8 => position_string::(list_array, offsets, values, element), + // Fallback to ScalarValue for complex types (nested arrays, etc.) + _ => position_fallback::(list_array, offsets, values, element), + } +} + +/// Compute the combined null buffer from list array and element nulls. +fn combined_nulls( + list_array_nulls: Option<&NullBuffer>, + element_nulls: Option<&NullBuffer>, +) -> Option { + match (list_array_nulls, element_nulls) { + (Some(a), Some(b)) => NullBuffer::union(Some(a), Some(b)), + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (None, None) => None, + } +} + +/// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer. +fn position_primitive( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result +where + T::Native: PartialEq, +{ + let values_typed = values.as_primitive::(); + let element_typed = element.as_primitive::(); + let num_rows = list_array.len(); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; + + for (row_index, w) in offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); + for i in start..end { + if !values_typed.is_null(i) && values_typed.value(i) == search_val { + result[row_index] = (i - start + 1) as i64; + break; + } + } + } + + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) +} + +/// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics). +fn position_float( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result +where + T::Native: PartialEq + num::Float, +{ + let values_typed = values.as_primitive::(); + let element_typed = element.as_primitive::(); + let num_rows = list_array.len(); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; + + for (row_index, w) in offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); + let search_is_nan = search_val.is_nan(); + for i in start..end { + if !values_typed.is_null(i) { + let v = values_typed.value(i); + if (search_is_nan && v.is_nan()) || v == search_val { + result[row_index] = (i - start + 1) as i64; + break; + } + } + } + } + + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) +} + +/// Boolean path. +fn position_boolean( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result { + let values_typed = values + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?; + let element_typed = element + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?; + let num_rows = list_array.len(); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; + + for (row_index, w) in offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); + for i in start..end { + if !values_typed.is_null(i) && values_typed.value(i) == search_val { + result[row_index] = (i - start + 1) as i64; + break; + } + } + } + + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) +} + +/// String path: downcast once, iterate using offsets into the flat string buffer. +fn position_string( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result { + let values_typed = values.as_string::(); + let element_typed = element.as_string::(); + let num_rows = list_array.len(); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; + + for (row_index, w) in offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); + for i in start..end { + if !values_typed.is_null(i) && values_typed.value(i) == search_val { + result[row_index] = (i - start + 1) as i64; + break; + } + } + } + + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) +} + +/// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison. +fn position_fallback( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result { + let num_rows = list_array.len(); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; + + for (row_index, w) in offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_scalar = ScalarValue::try_from_array(element, row_index)?; + for i in start..end { + if !values.is_null(i) { + let item_scalar = ScalarValue::try_from_array(values, i)?; + if search_scalar == item_scalar { + result[row_index] = (i - start + 1) as i64; + break; + } + } + } + } + + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct SparkArrayPositionFunc { + signature: Signature, +} + +impl Default for SparkArrayPositionFunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayPositionFunc { + pub fn new() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayPositionFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "spark_array_position" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + spark_array_position(&args.args) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 99f0a1eecc..5a503eba39 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -17,6 +17,7 @@ mod array_compact; mod array_insert; +mod array_position; mod arrays_overlap; mod arrays_zip; mod get_array_struct_fields; @@ -25,6 +26,7 @@ mod size; pub use array_compact::SparkArrayCompact; pub use array_insert::ArrayInsert; +pub use array_position::SparkArrayPositionFunc; pub use arrays_overlap::SparkArraysOverlap; pub use arrays_zip::SparkArraysZipFunc; pub use get_array_struct_fields::GetArrayStructFields; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 74e688cd1c..a3ea0e6816 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -23,8 +23,9 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArraysOverlap, SparkContains, - SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, + spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayPositionFunc, SparkArraysOverlap, + SparkContains, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate, + SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -201,6 +202,7 @@ pub fn create_comet_physical_fun_with_eval_mode( fn all_scalar_functions() -> Vec> { vec![ Arc::new(ScalarUDF::new_from_impl(SparkArrayCompact::default())), + Arc::new(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkArraysOverlap::default())), Arc::new(ScalarUDF::new_from_impl(SparkContains::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())), diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 768e4e0ed6..62159a40a1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -58,6 +58,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayJoin] -> CometArrayJoin, classOf[ArrayMax] -> CometArrayMax, classOf[ArrayMin] -> CometArrayMin, + classOf[ArrayPosition] -> CometArrayPosition, classOf[ArrayRemove] -> CometArrayRemove, classOf[ArrayRepeat] -> CometArrayRepeat, classOf[SortArray] -> CometSortArray, diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 14d4536fc1..f54eb6d67a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.annotation.tailrec import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -654,6 +654,38 @@ object CometSize extends CometExpressionSerde[Size] { } +object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase { + + override def getSupportLevel(expr: ArrayPosition): SupportLevel = Compatible() + + override def convert( + expr: ArrayPosition, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + if (expr.children.forall(_.foldable)) { + withInfo(expr, "all arguments are literals, falling back to Spark") + return None + } + // Check if input types are supported + val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet + for (dt <- inputTypes) { + if (!isTypeSupported(dt)) { + withInfo(expr, s"data type not supported: $dt") + return None + } + } + + val arrayExprProto = exprToProto(expr.left, inputs, binding) + val elementExprProto = exprToProto(expr.right, inputs, binding) + + // Use spark_array_position which returns Int64 and 0 when not found + // (matching Spark's behavior) + val optExpr = + scalarFunctionExprToProto("spark_array_position", arrayExprProto, elementExprProto) + optExprWithInfo(optExpr, expr, expr.left, expr.right) + } +} + object CometArraysZip extends CometExpressionSerde[ArraysZip] { private def isTypeSupported(dt: DataType): Boolean = { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql new file mode 100644 index 0000000000..132158ab6d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -0,0 +1,283 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_array_position(int_arr array, str_arr array, val int, str_val string) USING parquet + +statement +INSERT INTO test_array_position VALUES + (array(1, 2, 3, 4), array('a', 'b', 'c'), 2, 'b'), + (array(1, 2, NULL, 3), array('a', NULL, 'c'), 3, 'c'), + (array(10, 20, 30), array('x', 'y', 'z'), 99, 'w'), + (array(), array(), 1, 'a'), + (NULL, NULL, 1, 'a'), + (array(1, 1, 1), array('a', 'a', 'a'), 1, 'a'), + (array(5, 6, 7), array('p', 'q', 'r'), NULL, NULL) + +-- literal args fall back to Spark +query spark_answer_only +SELECT array_position(array(1, 2, 3, 4), 3) + +query spark_answer_only +SELECT array_position(array(1, 2, 3, 4), 5) + +query spark_answer_only +SELECT array_position(array('a', 'b', 'c'), 'b') + +query spark_answer_only +SELECT array_position(array(1, 2, NULL, 3), 3) + +query spark_answer_only +SELECT array_position(array(1, 2, 3), cast(NULL as int)) + +query spark_answer_only +SELECT array_position(cast(NULL as array), 1) + +query spark_answer_only +SELECT array_position(array(), 1) + +query spark_answer_only +SELECT array_position(array(1, 2, 1, 3), 1) + +-- column array + column value (includes NULL val row) +query +SELECT array_position(int_arr, val) FROM test_array_position + +-- column array + literal value +query +SELECT array_position(int_arr, 3) FROM test_array_position + +-- literal array + column value +query +SELECT array_position(array(1, 2, 3), val) FROM test_array_position + +-- string column array + column value (includes NULL str_val row) +query +SELECT array_position(str_arr, str_val) FROM test_array_position + +-- string column array + literal value +query +SELECT array_position(str_arr, 'c') FROM test_array_position + +-- expressions in array construction +query +SELECT array_position(array(val, val + 1, val + 2), val) FROM test_array_position + +-- boolean arrays +statement +CREATE TABLE test_ap_bool(arr array, val boolean) USING parquet + +statement +INSERT INTO test_ap_bool VALUES + (array(true, false, true), false), + (array(true, true), false), + (array(false, false), true), + (NULL, true), + (array(true, false), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_bool + +-- tinyint arrays +statement +CREATE TABLE test_ap_byte(arr array, val tinyint) USING parquet + +statement +INSERT INTO test_ap_byte VALUES + (array(cast(1 as tinyint), cast(2 as tinyint), cast(3 as tinyint)), cast(2 as tinyint)), + (array(cast(-128 as tinyint), cast(127 as tinyint)), cast(127 as tinyint)), + (NULL, cast(1 as tinyint)), + (array(cast(1 as tinyint)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_byte + +-- smallint arrays +statement +CREATE TABLE test_ap_short(arr array, val smallint) USING parquet + +statement +INSERT INTO test_ap_short VALUES + (array(cast(100 as smallint), cast(200 as smallint), cast(300 as smallint)), cast(200 as smallint)), + (NULL, cast(1 as smallint)), + (array(cast(1 as smallint)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_short + +-- bigint arrays +statement +CREATE TABLE test_ap_long(arr array, val bigint) USING parquet + +statement +INSERT INTO test_ap_long VALUES + (array(cast(1000000000000 as bigint), cast(2000000000000 as bigint)), cast(2000000000000 as bigint)), + (array(cast(-1 as bigint), cast(0 as bigint), cast(1 as bigint)), cast(0 as bigint)), + (NULL, cast(1 as bigint)), + (array(cast(1 as bigint)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_long + +-- float arrays +statement +CREATE TABLE test_ap_float(arr array, val float) USING parquet + +statement +INSERT INTO test_ap_float VALUES + (array(cast(1.1 as float), cast(2.2 as float), cast(3.3 as float)), cast(2.2 as float)), + (array(cast(0.0 as float), cast(-1.5 as float)), cast(-1.5 as float)), + (NULL, cast(1.0 as float)), + (array(cast(1.0 as float)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_float + +-- double arrays +statement +CREATE TABLE test_ap_double(arr array, val double) USING parquet + +statement +INSERT INTO test_ap_double VALUES + (array(1.1, 2.2, 3.3), 2.2), + (array(0.0, -1.5), -1.5), + (NULL, 1.0), + (array(1.0), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_double + +-- NaN handling for float arrays (Spark treats NaN == NaN) +query spark_answer_only +SELECT array_position(array(cast('NaN' as float), cast(1.0 as float)), cast('NaN' as float)) + +query spark_answer_only +SELECT array_position(array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)) + +-- NaN handling for double arrays (Spark treats NaN == NaN) +query spark_answer_only +SELECT array_position(array(cast('NaN' as double), 1.0), cast('NaN' as double)) + +query spark_answer_only +SELECT array_position(array(1.0, cast('NaN' as double), 2.0), cast('NaN' as double)) + +-- NaN handling with column data +statement +CREATE TABLE test_ap_nan(arr array, val float) USING parquet + +statement +INSERT INTO test_ap_nan VALUES + (array(cast('NaN' as float), cast(1.0 as float)), cast('NaN' as float)), + (array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)), + (array(cast(1.0 as float), cast(2.0 as float)), cast('NaN' as float)) + +query +SELECT array_position(arr, val) FROM test_ap_nan + +-- decimal arrays +statement +CREATE TABLE test_ap_decimal(arr array, val decimal(10,2)) USING parquet + +statement +INSERT INTO test_ap_decimal VALUES + (array(cast(1.10 as decimal(10,2)), cast(2.20 as decimal(10,2)), cast(3.30 as decimal(10,2))), cast(2.20 as decimal(10,2))), + (array(cast(0.00 as decimal(10,2))), cast(0.00 as decimal(10,2))), + (NULL, cast(1.00 as decimal(10,2))), + (array(cast(1.00 as decimal(10,2))), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_decimal + +-- date arrays +statement +CREATE TABLE test_ap_date(arr array, val date) USING parquet + +statement +INSERT INTO test_ap_date VALUES + (array(date '2024-01-01', date '2024-06-15', date '2024-12-31'), date '2024-06-15'), + (array(date '2000-01-01'), date '1999-12-31'), + (NULL, date '2024-01-01'), + (array(date '2024-01-01'), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_date + +-- nested array (exercises position_fallback code path) +query spark_answer_only +SELECT array_position(array(array(1, 2), array(3, 4)), array(1, 2)) + +query spark_answer_only +SELECT array_position(array(array(1, 2), array(3, 4)), array(5, 6)) + +-- nested int array column (exercises position_fallback natively) +statement +CREATE TABLE test_ap_nested_int(arr array>, val array) USING parquet + +statement +INSERT INTO test_ap_nested_int VALUES + (array(array(1, 2), array(3, 4), array(5, 6)), array(3, 4)), + (array(array(1, 2), array(3, 4)), array(5, 6)), + (array(array(1, 2), array(1, 2, 3)), array(1, 2)), + (NULL, array(1, 2)), + (array(array(1, 2)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_nested_int + +-- nested string array column (exercises position_fallback natively) +statement +CREATE TABLE test_ap_nested_str(arr array>, val array) USING parquet + +statement +INSERT INTO test_ap_nested_str VALUES + (array(array('a', 'b'), array('c', 'd')), array('c', 'd')), + (array(array('a', 'b'), array('c', 'd')), array('x', 'y')), + (NULL, array('a')), + (array(array('a')), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_nested_str + +-- timestamp arrays +statement +CREATE TABLE test_ap_ts(arr array, val timestamp) USING parquet + +statement +INSERT INTO test_ap_ts VALUES + (array(timestamp '2024-01-01 00:00:00', timestamp '2024-06-15 12:30:00'), timestamp '2024-06-15 12:30:00'), + (array(timestamp '2000-01-01 00:00:00'), timestamp '1999-12-31 23:59:59'), + (NULL, timestamp '2024-01-01 00:00:00'), + (array(timestamp '2024-01-01 00:00:00'), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_ts + +-- timestamp_ntz arrays +statement +CREATE TABLE test_ap_ts_ntz(arr array, val timestamp_ntz) USING parquet + +statement +INSERT INTO test_ap_ts_ntz VALUES + (array(timestamp_ntz '2024-01-01 00:00:00', timestamp_ntz '2024-06-15 12:30:00'), timestamp_ntz '2024-06-15 12:30:00'), + (array(timestamp_ntz '2000-01-01 00:00:00'), timestamp_ntz '1999-12-31 23:59:59'), + (NULL, timestamp_ntz '2024-01-01 00:00:00'), + (array(timestamp_ntz '2024-01-01 00:00:00'), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_ts_ntz diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala index 784923ca2e..44ef1a4735 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.benchmark +import org.apache.comet.CometConf + +// spotless:off /** * Benchmark to measure performance of Comet array expressions. To run this benchmark: * {{{ @@ -26,6 +29,7 @@ package org.apache.spark.sql.benchmark * }}} * Results will be written to "spark/benchmarks/CometArrayExpressionBenchmark-**results.txt". */ +// spotless:on object CometArrayExpressionBenchmark extends CometBenchmarkBase { private def buildWideIntArrayExpr(width: Int, modulus: Int): String = { @@ -86,6 +90,72 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase { } } + def arrayPositionBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("parquetV1Table") { + // Create a table with int arrays of size 10 and a search value + prepareTable( + dir, + spark.sql(s"""SELECT + | array( + | cast(value % 100 as int), + | cast((value + 1) % 100 as int), + | cast((value + 2) % 100 as int), + | cast((value + 3) % 100 as int), + | cast((value + 4) % 100 as int), + | cast((value + 5) % 100 as int), + | cast((value + 6) % 100 as int), + | cast((value + 7) % 100 as int), + | cast((value + 8) % 100 as int), + | cast((value + 9) % 100 as int) + | ) as int_arr, + | cast((value + 5) % 100 as int) as search_val + |FROM $tbl""".stripMargin)) + + val nativeScanConfig = + Map(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + runExpressionBenchmark( + "array_position - int array", + values, + "SELECT array_position(int_arr, search_val) FROM parquetV1Table", + nativeScanConfig) + } + } + + withTempPath { dir => + withTempTable("parquetV1Table") { + // Create a table with string arrays of size 10 and a search value + prepareTable( + dir, + spark.sql(s"""SELECT + | array( + | cast(value % 100 as string), + | cast((value + 1) % 100 as string), + | cast((value + 2) % 100 as string), + | cast((value + 3) % 100 as string), + | cast((value + 4) % 100 as string), + | cast((value + 5) % 100 as string), + | cast((value + 6) % 100 as string), + | cast((value + 7) % 100 as string), + | cast((value + 8) % 100 as string), + | cast((value + 9) % 100 as string) + | ) as str_arr, + | cast((value + 5) % 100 as string) as search_val + |FROM $tbl""".stripMargin)) + + val nativeScanConfig = + Map(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + runExpressionBenchmark( + "array_position - string array", + values, + "SELECT array_position(str_arr, search_val) FROM parquetV1Table", + nativeScanConfig) + } + } + } + override def runCometBenchmark(mainArgs: Array[String]): Unit = { val values = 4 * 1024 * 1024 @@ -104,5 +174,9 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase { runBenchmarkWithTable("sortArrayIntAscFirstElement", values) { v => sortArrayIntAscFirstElementBenchmark(v, width = 32) } + + runBenchmarkWithTable("ArrayPosition", values) { v => + arrayPositionBenchmark(v) + } } }