From 3911ed0a1f70223f3ae5ff146cdfb89464404060 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 14 Jan 2026 19:27:29 -0700 Subject: [PATCH 01/18] feat: add support for array_position expression Implements Spark's array_position function which returns the 1-based position of an element in an array, returning 0 if not found. This required a custom Rust implementation because DataFusion's array_position returns UInt64 and null when not found, while Spark returns Int64 (LongType) and 0. Key implementation details: - Returns Int64 to match Spark's LongType - Returns 0 when element is not found (Spark behavior) - Returns null when array is null or search element is null - Supports both List and LargeList array types Closes #3153 Co-Authored-By: Claude Opus 4.5 --- .../src/array_funcs/array_position.rs | 148 ++++++++++++++++++ native/spark-expr/src/array_funcs/mod.rs | 2 + native/spark-expr/src/comet_scalar_funcs.rs | 5 +- .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../scala/org/apache/comet/serde/arrays.scala | 28 +++- .../comet/CometArrayExpressionSuite.scala | 33 ++++ 6 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 native/spark-expr/src/array_funcs/array_position.rs diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs new file mode 100644 index 0000000000..868e3307b6 --- /dev/null +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// Spark array_position() function that returns the 1-based position of an element in an array. +/// Returns 0 if the element is not found (Spark behavior differs from DataFusion which returns null). 
+pub fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { + if args.len() != 2 { + return exec_err!("array_position function takes exactly two arguments"); + } + + // Convert all arguments to arrays for consistent processing + let len = args + .iter() + .fold(Option::<usize>::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let arrays = ColumnarValue::values_to_arrays(args)?; + + let result = array_position_inner(&arrays)?; + + if is_scalar { + let scalar = ScalarValue::try_from_array(&result, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } else { + Ok(ColumnarValue::Array(result)) + } +} + +fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> { + let array = &args[0]; + let element = &args[1]; + + match array.data_type() { + DataType::List(_) => generic_array_position::<i32>(array, element), + DataType::LargeList(_) => generic_array_position::<i64>(array, element), + other => exec_err!("array_position does not support type '{other:?}'"), + } +} + +fn generic_array_position<O: OffsetSizeTrait>( + array: &ArrayRef, + element: &ArrayRef, +) -> Result<ArrayRef, DataFusionError> { + let list_array = array + .as_any() + .downcast_ref::<GenericListArray<O>>() + .unwrap(); + + let mut data = Vec::with_capacity(list_array.len()); + + for row_index in 0..list_array.len() { + if list_array.is_null(row_index) { + // Null array returns null position (same as Spark) + data.push(None); + } else if element.is_null(row_index) { + // Searching for null element returns null in Spark + data.push(None); + } else { + let list_array_row = list_array.value(row_index); + + // Get the search element as a scalar + let element_scalar = ScalarValue::try_from_array(element, row_index)?; + + // Compare element to each item in the list + let mut position: i64 = 0; + for i in 0..list_array_row.len() { + let list_item_scalar = ScalarValue::try_from_array(&list_array_row, i)?; + + // null != anything in Spark array_position + if !list_item_scalar.is_null() && element_scalar == list_item_scalar { + position = (i + 1) as i64; // 1-indexed + break; + } + } + + data.push(Some(position)); + } + } + + Ok(Arc::new(Int64Array::from(data))) +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct SparkArrayPositionFunc { + signature: Signature, +} + +impl Default for SparkArrayPositionFunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayPositionFunc { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayPositionFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "spark_array_position" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> { + spark_array_position(&args.args) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 063dd7a5aa..2961ba514e 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -16,12 +16,14 @@ // under the License.
mod array_insert; +mod array_position; mod array_repeat; mod get_array_struct_fields; mod list_extract; mod size; pub use array_insert::ArrayInsert; +pub use array_position::{spark_array_position, SparkArrayPositionFunc}; pub use array_repeat::spark_array_repeat; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 8384a4646a..cc188bea8c 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, - spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc, - SparkStringSpace, + spark_unhex, spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, + SparkDateTrunc, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -191,6 +191,7 @@ pub fn create_comet_physical_fun_with_eval_mode( fn all_scalar_functions() -> Vec> { vec![ + Arc::new(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e50b1d80e6..287fd7aa37 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -56,6 +56,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayJoin] -> CometArrayJoin, classOf[ArrayMax] -> CometArrayMax, classOf[ArrayMin] -> CometArrayMin, + classOf[ArrayPosition] -> CometArrayPosition, classOf[ArrayRemove] -> CometArrayRemove, classOf[ArrayRepeat] -> CometArrayRepeat, classOf[ArraysOverlap] -> CometArraysOverlap, diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index b552a071d6..3a8c8f0710 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -597,6 +597,32 @@ object CometSize extends CometExpressionSerde[Size] { } +object CometArrayPosition extends 
CometExpressionSerde[ArrayPosition] with ArraysBase { + + override def convert( + expr: ArrayPosition, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + // Check if input types are supported + val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet + for (dt <- inputTypes) { + if (!isTypeSupported(dt)) { + withInfo(expr, s"data type not supported: $dt") + return None + } + } + + val arrayExprProto = exprToProto(expr.left, inputs, binding) + val elementExprProto = exprToProto(expr.right, inputs, binding) + + // Use spark_array_position which returns Int64 and 0 when not found + // (matching Spark's behavior) + val optExpr = + scalarFunctionExprToProto("spark_array_position", arrayExprProto, elementExprProto) + optExprWithInfo(optExpr, expr, expr.left, expr.right) + } +} + trait ArraysBase { def isTypeSupported(dt: DataType): Boolean = { diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index cf49117364..3dd66f22ee 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -890,4 +890,37 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } } + + test("array_position") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + + // Tests with literal values need constant folding disabled + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + // Basic array_position tests with integers + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(1, 2, 3, 4), 3) FROM t1 LIMIT 1")) + // Element not found should return 0 + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(1, 2, 3, 4), 5) FROM t1 LIMIT 1")) + // Test with strings + checkSparkAnswerAndOperator( + sql("SELECT array_position(array('a', 'b', 'c'), 'b') FROM t1 LIMIT 1")) + // Test with null element in array + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(1, 2, null, 3), 3) FROM t1 LIMIT 1")) + } + + // Test with column values + checkSparkAnswerAndOperator(sql("SELECT array_position(array(1, _4, 3), _4) FROM t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_position(array(_4, _4 + 1, _4 + 2), _4) FROM t1")) + } + } + } } From 8cb27ec7c67aa02715e9e862babca7daf5cd900f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 15 Jan 2026 08:15:34 -0700 Subject: [PATCH 02/18] update docs --- docs/source/user-guide/latest/configs.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 1a273ad033..bbfdf4d39a 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -202,6 +202,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.ArrayJoin.enabled` | Enable Comet acceleration for `ArrayJoin` | true | | `spark.comet.expression.ArrayMax.enabled` | Enable Comet acceleration for `ArrayMax` | true | | `spark.comet.expression.ArrayMin.enabled` | Enable Comet acceleration for `ArrayMin` | true | +| `spark.comet.expression.ArrayPosition.enabled` | Enable Comet acceleration 
for `ArrayPosition` | true | | `spark.comet.expression.ArrayRemove.enabled` | Enable Comet acceleration for `ArrayRemove` | true | | `spark.comet.expression.ArrayRepeat.enabled` | Enable Comet acceleration for `ArrayRepeat` | true | | `spark.comet.expression.ArrayUnion.enabled` | Enable Comet acceleration for `ArrayUnion` | true | From 258442e63327ebef0057feeeeefcdac37b04bd98 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 10 Feb 2026 11:24:28 -0700 Subject: [PATCH 03/18] Migrate array_position tests to SQL file-based approach Move array_position tests from CometArrayExpressionSuite to a SQL file test and fall back to Spark when all arguments are literals. Co-Authored-By: Claude Opus 4.6 --- .../scala/org/apache/comet/serde/arrays.scala | 4 + .../expressions/array/array_position.sql | 79 +++++++++++++++++++ .../comet/CometArrayExpressionSuite.scala | 32 -------- 3 files changed, 83 insertions(+), 32 deletions(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/array/array_position.sql diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index b05fab9df0..385e8be6ac 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -640,6 +640,10 @@ object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with Array expr: ArrayPosition, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { + if (expr.children.forall(_.foldable)) { + withInfo(expr, "all arguments are literals, falling back to Spark") + return None + } // Check if input types are supported val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet for (dt <- inputTypes) { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql new file mode 100644 index 0000000000..608a112b82 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -0,0 +1,79 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_array_position(int_arr array<int>, str_arr array<string>, val int, str_val string) USING parquet + +statement +INSERT INTO test_array_position VALUES + (array(1, 2, 3, 4), array('a', 'b', 'c'), 2, 'b'), + (array(1, 2, NULL, 3), array('a', NULL, 'c'), 3, 'c'), + (array(10, 20, 30), array('x', 'y', 'z'), 99, 'w'), + (array(), array(), 1, 'a'), + (NULL, NULL, 1, 'a'), + (array(1, 1, 1), array('a', 'a', 'a'), 1, 'a') + +-- literal args fall back to Spark +query spark_answer_only +SELECT array_position(array(1, 2, 3, 4), 3) + +query spark_answer_only +SELECT array_position(array(1, 2, 3, 4), 5) + +query spark_answer_only +SELECT array_position(array('a', 'b', 'c'), 'b') + +query spark_answer_only +SELECT array_position(array(1, 2, NULL, 3), 3) + +query spark_answer_only +SELECT array_position(array(1, 2, 3), cast(NULL as int)) + +query spark_answer_only +SELECT array_position(cast(NULL as array<int>), 1) + +query spark_answer_only +SELECT array_position(array(), 1) + +query spark_answer_only +SELECT array_position(array(1, 2, 1, 3), 1) + +-- column array + column value +query +SELECT array_position(int_arr, val) FROM test_array_position + +-- column array + literal value +query +SELECT array_position(int_arr, 3) FROM test_array_position + +-- literal array + column value +query +SELECT array_position(array(1, 2, 3), val) FROM test_array_position + +-- string column array + column value +query +SELECT array_position(str_arr, str_val) FROM test_array_position + +-- string column array + literal value +query +SELECT array_position(str_arr, 'c') FROM test_array_position + +-- expressions in array construction +query +SELECT array_position(array(val, val + 1, val + 2), val) FROM test_array_position diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 7b6d20e764..9a34553ea0 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -923,36 +923,4 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } - test("array_position") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") - - // Tests with literal values need constant folding disabled - withSQLConf( - SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> - "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { - // Basic array_position tests with integers - checkSparkAnswerAndOperator( - sql("SELECT array_position(array(1, 2, 3, 4), 3) FROM t1 LIMIT 1")) - // Element not found should return 0 - checkSparkAnswerAndOperator( - sql("SELECT array_position(array(1, 2, 3, 4), 5) FROM t1 LIMIT 1")) - // Test with strings - checkSparkAnswerAndOperator( - sql("SELECT array_position(array('a', 'b', 'c'), 'b') FROM t1 LIMIT 1")) - // Test with null element in array - checkSparkAnswerAndOperator( - sql("SELECT array_position(array(1, 2, null, 3), 3) FROM t1 LIMIT 1")) - } - - // Test with column values - checkSparkAnswerAndOperator(sql("SELECT array_position(array(1, _4, 3), _4) FROM t1")) - checkSparkAnswerAndOperator( - sql("SELECT array_position(array(_4, _4 + 1, _4 + 2), _4) FROM t1")) - } - } - } } From b7165849b8c8816d864632c84022b3ab7ecef5e8
Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Feb 2026 07:44:00 -0700 Subject: [PATCH 04/18] revert whitespace change in CometArrayExpressionSuite.scala Co-Authored-By: Claude Opus 4.6 --- .../test/scala/org/apache/comet/CometArrayExpressionSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 9a34553ea0..b22d0f72db 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -922,5 +922,4 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } } - } From 30c0de7b7d7f49f48103151c804e07738a1adb55 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Feb 2026 07:55:48 -0700 Subject: [PATCH 05/18] fix merge artifacts and expand array_position test coverage Remove stray array_repeat references from merge conflict resolution. Add NULL val row to test data and add tests for all supported array element types: boolean, tinyint, smallint, bigint, float, double, decimal, date, and timestamp. Co-Authored-By: Claude Opus 4.6 --- native/spark-expr/src/array_funcs/mod.rs | 2 - native/spark-expr/src/comet_scalar_funcs.rs | 6 +- .../expressions/array/array_position.sql | 133 +++++++++++++++++- 3 files changed, 133 insertions(+), 8 deletions(-) diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 2961ba514e..407cd4661b 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -17,14 +17,12 @@ mod array_insert; mod array_position; -mod array_repeat; mod get_array_struct_fields; mod list_extract; mod size; pub use array_insert::ArrayInsert; pub use array_position::{spark_array_position, SparkArrayPositionFunc}; -pub use array_repeat::spark_array_repeat; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; pub use size::{spark_size, SparkSizeFunc}; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index d6795d3126..3d339bc494 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,9 +20,9 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, - spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, - spark_unhex, spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, + spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, + spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, + spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql index 608a112b82..13873d6f4c 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql +++ 
b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -27,7 +27,8 @@ INSERT INTO test_array_position VALUES (array(10, 20, 30), array('x', 'y', 'z'), 99, 'w'), (array(), array(), 1, 'a'), (NULL, NULL, 1, 'a'), - (array(1, 1, 1), array('a', 'a', 'a'), 1, 'a') + (array(1, 1, 1), array('a', 'a', 'a'), 1, 'a'), + (array(5, 6, 7), array('p', 'q', 'r'), NULL, NULL) -- literal args fall back to Spark query spark_answer_only @@ -54,7 +55,7 @@ SELECT array_position(array(), 1) query spark_answer_only SELECT array_position(array(1, 2, 1, 3), 1) --- column array + column value +-- column array + column value (includes NULL val row) query SELECT array_position(int_arr, val) FROM test_array_position @@ -66,7 +67,7 @@ SELECT array_position(int_arr, 3) FROM test_array_position query SELECT array_position(array(1, 2, 3), val) FROM test_array_position --- string column array + column value +-- string column array + column value (includes NULL str_val row) query SELECT array_position(str_arr, str_val) FROM test_array_position @@ -77,3 +78,129 @@ SELECT array_position(str_arr, 'c') FROM test_array_position -- expressions in array construction query SELECT array_position(array(val, val + 1, val + 2), val) FROM test_array_position + +-- boolean arrays +statement +CREATE TABLE test_ap_bool(arr array<boolean>, val boolean) USING parquet + +statement +INSERT INTO test_ap_bool VALUES + (array(true, false, true), false), + (array(true, true), false), + (array(false, false), true), + (NULL, true), + (array(true, false), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_bool + +-- tinyint arrays +statement +CREATE TABLE test_ap_byte(arr array<tinyint>, val tinyint) USING parquet + +statement +INSERT INTO test_ap_byte VALUES + (array(cast(1 as tinyint), cast(2 as tinyint), cast(3 as tinyint)), cast(2 as tinyint)), + (array(cast(-128 as tinyint), cast(127 as tinyint)), cast(127 as tinyint)), + (NULL, cast(1 as tinyint)), + (array(cast(1 as tinyint)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_byte + +-- smallint arrays +statement +CREATE TABLE test_ap_short(arr array<smallint>, val smallint) USING parquet + +statement +INSERT INTO test_ap_short VALUES + (array(cast(100 as smallint), cast(200 as smallint), cast(300 as smallint)), cast(200 as smallint)), + (NULL, cast(1 as smallint)), + (array(cast(1 as smallint)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_short + +-- bigint arrays +statement +CREATE TABLE test_ap_long(arr array<bigint>, val bigint) USING parquet + +statement +INSERT INTO test_ap_long VALUES + (array(cast(1000000000000 as bigint), cast(2000000000000 as bigint)), cast(2000000000000 as bigint)), + (array(cast(-1 as bigint), cast(0 as bigint), cast(1 as bigint)), cast(0 as bigint)), + (NULL, cast(1 as bigint)), + (array(cast(1 as bigint)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_long + +-- float arrays +statement +CREATE TABLE test_ap_float(arr array<float>, val float) USING parquet + +statement +INSERT INTO test_ap_float VALUES + (array(cast(1.1 as float), cast(2.2 as float), cast(3.3 as float)), cast(2.2 as float)), + (array(cast(0.0 as float), cast(-1.5 as float)), cast(-1.5 as float)), + (NULL, cast(1.0 as float)), + (array(cast(1.0 as float)), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_float + +-- double arrays +statement +CREATE TABLE test_ap_double(arr array<double>, val double) USING parquet + +statement +INSERT INTO test_ap_double VALUES + (array(1.1, 2.2, 3.3), 2.2), + (array(0.0, -1.5), -1.5), + (NULL, 1.0), + (array(1.0), NULL) + +query
+SELECT array_position(arr, val) FROM test_ap_double + +-- decimal arrays +statement +CREATE TABLE test_ap_decimal(arr array<decimal(10,2)>, val decimal(10,2)) USING parquet + +statement +INSERT INTO test_ap_decimal VALUES + (array(cast(1.10 as decimal(10,2)), cast(2.20 as decimal(10,2)), cast(3.30 as decimal(10,2))), cast(2.20 as decimal(10,2))), + (array(cast(0.00 as decimal(10,2))), cast(0.00 as decimal(10,2))), + (NULL, cast(1.00 as decimal(10,2))), + (array(cast(1.00 as decimal(10,2))), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_decimal + +-- date arrays +statement +CREATE TABLE test_ap_date(arr array<date>, val date) USING parquet + +statement +INSERT INTO test_ap_date VALUES + (array(date '2024-01-01', date '2024-06-15', date '2024-12-31'), date '2024-06-15'), + (array(date '2000-01-01'), date '1999-12-31'), + (NULL, date '2024-01-01'), + (array(date '2024-01-01'), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_date + +-- timestamp arrays +statement +CREATE TABLE test_ap_ts(arr array<timestamp>, val timestamp) USING parquet + +statement +INSERT INTO test_ap_ts VALUES + (array(timestamp '2024-01-01 00:00:00', timestamp '2024-06-15 12:30:00'), timestamp '2024-06-15 12:30:00'), + (array(timestamp '2000-01-01 00:00:00'), timestamp '1999-12-31 23:59:59'), + (NULL, timestamp '2024-01-01 00:00:00'), + (array(timestamp '2024-01-01 00:00:00'), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_ts From cd3a577372310a39873f1f1edfe3a39d073807ee Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Feb 2026 07:59:49 -0700 Subject: [PATCH 06/18] rustfmt Co-Authored-By: Claude Opus 4.6 --- native/spark-expr/src/comet_scalar_funcs.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 3d339bc494..be6996cd77 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, - SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, + spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkBitwiseCount, SparkContains, + SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; From e4620e9ca09b42414449536dd8a0ff022e732d39 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 12 Mar 2026 07:43:26 -0600 Subject: [PATCH 07/18] fix: remove stray merge conflict marker --- native/spark-expr/src/comet_scalar_funcs.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index d89b70985c..22a2b83e20 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,7 +22,6 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, -<<<<<<< HEAD spark_unscaled_value, EvalMode, SparkArrayPositionFunc, SparkContains, SparkDateDiff,
SparkDateTrunc, SparkMakeDate, SparkSizeFunc, }; From 6cf8f6a597d43292b4e56af53f2dbe97f6374aa7 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 16 Mar 2026 12:25:59 -0600 Subject: [PATCH 08/18] feat: optimize array_position with typed array comparison and address review feedback - Use typed array downcasting instead of ScalarValue for element comparison, improving performance from 0.4X to 0.7-0.8X of Spark - Add getSupportLevel override marking as Incompatible (NaN equality) - Add NaN edge case tests for float/double arrays - Add CometArrayExpressionBenchmark microbenchmark - Make spark_array_position function private - Update docs to mark array_position as supported --- docs/spark_expressions_support.md | 2 +- .../src/array_funcs/array_position.rs | 133 ++++++++++++++---- native/spark-expr/src/array_funcs/mod.rs | 2 +- .../scala/org/apache/comet/serde/arrays.scala | 2 + .../expressions/array/array_position.sql | 28 ++++ .../CometArrayExpressionBenchmark.scala | 106 ++++++++++++++ 6 files changed, 247 insertions(+), 26 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 5474894108..a7b29a5c16 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -93,7 +93,7 @@ - [x] array_join - [x] array_max - [ ] array_min -- [ ] array_position +- [x] array_position - [x] array_remove - [x] array_repeat - [x] array_union diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs index 868e3307b6..42db587935 100644 --- a/native/spark-expr/src/array_funcs/array_position.rs +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -15,8 +15,13 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait, +}; +use arrow::datatypes::{ + DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, TimestampMicrosecondType, +}; use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -26,7 +31,7 @@ use std::sync::Arc; /// Spark array_position() function that returns the 1-based position of an element in an array. /// Returns 0 if the element is not found (Spark behavior differs from DataFusion which returns null). -pub fn spark_array_position(args: &[ColumnarValue]) -> Result { +fn spark_array_position(args: &[ColumnarValue]) -> Result { if args.len() != 2 { return exec_err!("array_position function takes exactly two arguments"); } @@ -63,6 +68,105 @@ fn array_position_inner(args: &[ArrayRef]) -> Result } } +/// Find the 1-based position of `search_val` in a typed primitive array. +/// Returns 0 if not found. +macro_rules! 
find_position_primitive { + ($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{ + let items = $list_items.as_primitive::<$arrow_type>(); + let search = $element.as_primitive::<$arrow_type>(); + let search_val = search.value($row_index); + let mut pos: i64 = 0; + for i in 0..items.len() { + if !items.is_null(i) && items.value(i) == search_val { + pos = (i + 1) as i64; + break; + } + } + pos + }}; +} + +fn find_position_in_row( + list_items: &ArrayRef, + element: &ArrayRef, + row_index: usize, +) -> Result { + let pos = match list_items.data_type() { + DataType::Boolean => { + let items = list_items.as_any().downcast_ref::().unwrap(); + let search = element.as_any().downcast_ref::().unwrap(); + let search_val = search.value(row_index); + let mut pos: i64 = 0; + for i in 0..items.len() { + if !items.is_null(i) && items.value(i) == search_val { + pos = (i + 1) as i64; + break; + } + } + pos + } + DataType::Int8 => find_position_primitive!(list_items, element, row_index, Int8Type), + DataType::Int16 => find_position_primitive!(list_items, element, row_index, Int16Type), + DataType::Int32 => find_position_primitive!(list_items, element, row_index, Int32Type), + DataType::Int64 => find_position_primitive!(list_items, element, row_index, Int64Type), + DataType::Float32 => { + find_position_primitive!(list_items, element, row_index, Float32Type) + } + DataType::Float64 => { + find_position_primitive!(list_items, element, row_index, Float64Type) + } + DataType::Decimal128(_, _) => { + find_position_primitive!(list_items, element, row_index, Decimal128Type) + } + DataType::Date32 => { + find_position_primitive!(list_items, element, row_index, Date32Type) + } + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { + find_position_primitive!(list_items, element, row_index, TimestampMicrosecondType) + } + DataType::Utf8 => { + let items = list_items.as_string::(); + let search = element.as_string::(); + let search_val = search.value(row_index); + let mut pos: i64 = 0; + for i in 0..items.len() { + if !items.is_null(i) && items.value(i) == search_val { + pos = (i + 1) as i64; + break; + } + } + pos + } + DataType::LargeUtf8 => { + let items = list_items.as_string::(); + let search = element.as_string::(); + let search_val = search.value(row_index); + let mut pos: i64 = 0; + for i in 0..items.len() { + if !items.is_null(i) && items.value(i) == search_val { + pos = (i + 1) as i64; + break; + } + } + pos + } + // Fallback to ScalarValue for complex types (nested arrays, etc.) 
+ _ => { + let element_scalar = ScalarValue::try_from_array(element, row_index)?; + let mut pos: i64 = 0; + for i in 0..list_items.len() { + let item_scalar = ScalarValue::try_from_array(list_items, i)?; + if !item_scalar.is_null() && element_scalar == item_scalar { + pos = (i + 1) as i64; + break; + } + } + pos + } + }; + Ok(pos) +} + fn generic_array_position( array: &ArrayRef, element: &ArrayRef, @@ -75,30 +179,11 @@ fn generic_array_position( let mut data = Vec::with_capacity(list_array.len()); for row_index in 0..list_array.len() { - if list_array.is_null(row_index) { - // Null array returns null position (same as Spark) - data.push(None); - } else if element.is_null(row_index) { - // Searching for null element returns null in Spark + if list_array.is_null(row_index) || element.is_null(row_index) { data.push(None); } else { let list_array_row = list_array.value(row_index); - - // Get the search element as a scalar - let element_scalar = ScalarValue::try_from_array(element, row_index)?; - - // Compare element to each item in the list - let mut position: i64 = 0; - for i in 0..list_array_row.len() { - let list_item_scalar = ScalarValue::try_from_array(&list_array_row, i)?; - - // null != anything in Spark array_position - if !list_item_scalar.is_null() && element_scalar == list_item_scalar { - position = (i + 1) as i64; // 1-indexed - break; - } - } - + let position = find_position_in_row(&list_array_row, element, row_index)?; data.push(Some(position)); } } diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 407cd4661b..a613ae0acf 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -22,7 +22,7 @@ mod list_extract; mod size; pub use array_insert::ArrayInsert; -pub use array_position::{spark_array_position, SparkArrayPositionFunc}; +pub use array_position::SparkArrayPositionFunc; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; pub use size::{spark_size, SparkSizeFunc}; diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 70d15fbae9..94ee5ee9d1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -664,6 +664,8 @@ object CometSize extends CometExpressionSerde[Size] { object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase { + override def getSupportLevel(expr: ArrayPosition): SupportLevel = Incompatible(None) + override def convert( expr: ArrayPosition, inputs: Seq[Attribute], diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql index 13873d6f4c..08f53964c6 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -15,6 +15,7 @@ -- specific language governing permissions and limitations -- under the License. 
+-- Config: spark.comet.expression.ArrayPosition.allowIncompatible=true -- ConfigMatrix: parquet.enable.dictionary=false,true statement @@ -163,6 +164,33 @@ INSERT INTO test_ap_double VALUES query SELECT array_position(arr, val) FROM test_ap_double +-- NaN handling for float arrays (Spark treats NaN == NaN) +query spark_answer_only +SELECT array_position(array(cast('NaN' as float), cast(1.0 as float)), cast('NaN' as float)) + +query spark_answer_only +SELECT array_position(array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)) + +-- NaN handling for double arrays (Spark treats NaN == NaN) +query spark_answer_only +SELECT array_position(array(cast('NaN' as double), 1.0), cast('NaN' as double)) + +query spark_answer_only +SELECT array_position(array(1.0, cast('NaN' as double), 2.0), cast('NaN' as double)) + +-- NaN handling with column data +statement +CREATE TABLE test_ap_nan(arr array, val float) USING parquet + +statement +INSERT INTO test_ap_nan VALUES + (array(cast('NaN' as float), cast(1.0 as float)), cast('NaN' as float)), + (array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)), + (array(cast(1.0 as float), cast(2.0 as float)), cast('NaN' as float)) + +query ignore(NaN equality: IEEE 754 says NaN != NaN but Spark treats NaN == NaN) +SELECT array_position(arr, val) FROM test_ap_nan + -- decimal arrays statement CREATE TABLE test_ap_decimal(arr array, val decimal(10,2)) USING parquet diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala new file mode 100644 index 0000000000..bc5d4af597 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +// spotless:off +/** + * Benchmark to measure performance of Comet array expressions. To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometArrayExpressionBenchmark + * }}} + * Results will be written to "spark/benchmarks/CometArrayExpressionBenchmark-**results.txt". 
+ */ +// spotless:on +object CometArrayExpressionBenchmark extends CometBenchmarkBase { + + def arrayPositionBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("parquetV1Table") { + // Create a table with int arrays of size 10 and a search value + prepareTable( + dir, + spark.sql(s"""SELECT + | array( + | cast(value % 100 as int), + | cast((value + 1) % 100 as int), + | cast((value + 2) % 100 as int), + | cast((value + 3) % 100 as int), + | cast((value + 4) % 100 as int), + | cast((value + 5) % 100 as int), + | cast((value + 6) % 100 as int), + | cast((value + 7) % 100 as int), + | cast((value + 8) % 100 as int), + | cast((value + 9) % 100 as int) + | ) as int_arr, + | cast((value + 5) % 100 as int) as search_val + |FROM $tbl""".stripMargin)) + + val extraConfigs = + Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true") + + runExpressionBenchmark( + "array_position - int array", + values, + "SELECT array_position(int_arr, search_val) FROM parquetV1Table", + extraConfigs) + } + } + + withTempPath { dir => + withTempTable("parquetV1Table") { + // Create a table with string arrays of size 10 and a search value + prepareTable( + dir, + spark.sql(s"""SELECT + | array( + | cast(value % 100 as string), + | cast((value + 1) % 100 as string), + | cast((value + 2) % 100 as string), + | cast((value + 3) % 100 as string), + | cast((value + 4) % 100 as string), + | cast((value + 5) % 100 as string), + | cast((value + 6) % 100 as string), + | cast((value + 7) % 100 as string), + | cast((value + 8) % 100 as string), + | cast((value + 9) % 100 as string) + | ) as str_arr, + | cast((value + 5) % 100 as string) as search_val + |FROM $tbl""".stripMargin)) + + val extraConfigs = + Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true") + + runExpressionBenchmark( + "array_position - string array", + values, + "SELECT array_position(str_arr, search_val) FROM parquetV1Table", + extraConfigs) + } + } + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + val values = 1024 * 1024 + + runBenchmarkWithTable("ArrayPosition", values) { v => + arrayPositionBenchmark(v) + } + } +} From fad07b98f798ae38fbf2ea0ed009cb8dd012c20d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 16 Mar 2026 12:58:10 -0600 Subject: [PATCH 09/18] fix: explain NaN incompatibility in ArrayPosition getSupportLevel --- spark/src/main/scala/org/apache/comet/serde/arrays.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 94ee5ee9d1..1a611b2007 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -664,7 +664,10 @@ object CometSize extends CometExpressionSerde[Size] { object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase { - override def getSupportLevel(expr: ArrayPosition): SupportLevel = Incompatible(None) + override def getSupportLevel(expr: ArrayPosition): SupportLevel = + Incompatible(Some( + "element comparison uses IEEE 754 equality where NaN != NaN, " + + "but Spark treats NaN as equal to NaN")) override def convert( expr: ArrayPosition, From 24318d3368536c916041947c0d1cbfeaf48e2e16 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 16 Mar 2026 13:04:06 -0600 Subject: [PATCH 10/18] fix: handle NaN equality in array_position to match Spark semantics Treat NaN == NaN in float/double comparisons, matching 
Spark's ordering.equiv() behavior. This makes array_position Compatible rather than Incompatible. --- .../src/array_funcs/array_position.rs | 29 +++++++++++++++---- .../scala/org/apache/comet/serde/arrays.scala | 5 +--- .../expressions/array/array_position.sql | 3 +- .../CometArrayExpressionBenchmark.scala | 12 ++------ 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs index 42db587935..599d1ce951 100644 --- a/native/spark-expr/src/array_funcs/array_position.rs +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -86,6 +86,27 @@ macro_rules! find_position_primitive { }}; } +/// Float-aware comparison that treats NaN == NaN (matching Spark's ordering.equiv() semantics). +macro_rules! find_position_float { + ($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{ + let items = $list_items.as_primitive::<$arrow_type>(); + let search = $element.as_primitive::<$arrow_type>(); + let search_val = search.value($row_index); + let search_is_nan = search_val.is_nan(); + let mut pos: i64 = 0; + for i in 0..items.len() { + if !items.is_null(i) { + let item_val = items.value(i); + if (search_is_nan && item_val.is_nan()) || item_val == search_val { + pos = (i + 1) as i64; + break; + } + } + } + pos + }}; +} + fn find_position_in_row( list_items: &ArrayRef, element: &ArrayRef, @@ -109,12 +130,8 @@ fn find_position_in_row( DataType::Int16 => find_position_primitive!(list_items, element, row_index, Int16Type), DataType::Int32 => find_position_primitive!(list_items, element, row_index, Int32Type), DataType::Int64 => find_position_primitive!(list_items, element, row_index, Int64Type), - DataType::Float32 => { - find_position_primitive!(list_items, element, row_index, Float32Type) - } - DataType::Float64 => { - find_position_primitive!(list_items, element, row_index, Float64Type) - } + DataType::Float32 => find_position_float!(list_items, element, row_index, Float32Type), + DataType::Float64 => find_position_float!(list_items, element, row_index, Float64Type), DataType::Decimal128(_, _) => { find_position_primitive!(list_items, element, row_index, Decimal128Type) } diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 1a611b2007..9943c20397 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -664,10 +664,7 @@ object CometSize extends CometExpressionSerde[Size] { object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase { - override def getSupportLevel(expr: ArrayPosition): SupportLevel = - Incompatible(Some( - "element comparison uses IEEE 754 equality where NaN != NaN, " + - "but Spark treats NaN as equal to NaN")) + override def getSupportLevel(expr: ArrayPosition): SupportLevel = Compatible() override def convert( expr: ArrayPosition, diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql index 08f53964c6..af04c0cd9d 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -15,7 +15,6 @@ -- specific language governing permissions and limitations -- under the License. 
--- Config: spark.comet.expression.ArrayPosition.allowIncompatible=true -- ConfigMatrix: parquet.enable.dictionary=false,true statement @@ -188,7 +187,7 @@ INSERT INTO test_ap_nan VALUES (array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)), (array(cast(1.0 as float), cast(2.0 as float)), cast('NaN' as float)) -query ignore(NaN equality: IEEE 754 says NaN != NaN but Spark treats NaN == NaN) +query SELECT array_position(arr, val) FROM test_ap_nan -- decimal arrays diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala index bc5d4af597..b197209011 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala @@ -52,14 +52,10 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase { | cast((value + 5) % 100 as int) as search_val |FROM $tbl""".stripMargin)) - val extraConfigs = - Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true") - runExpressionBenchmark( "array_position - int array", values, - "SELECT array_position(int_arr, search_val) FROM parquetV1Table", - extraConfigs) + "SELECT array_position(int_arr, search_val) FROM parquetV1Table") } } @@ -84,14 +80,10 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase { | cast((value + 5) % 100 as string) as search_val |FROM $tbl""".stripMargin)) - val extraConfigs = - Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true") - runExpressionBenchmark( "array_position - string array", values, - "SELECT array_position(str_arr, search_val) FROM parquetV1Table", - extraConfigs) + "SELECT array_position(str_arr, search_val) FROM parquetV1Table") } } } From b7ddad16ab22706171857869177174fe0d83f0e9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 16 Mar 2026 13:27:11 -0600 Subject: [PATCH 11/18] perf: use flat values buffer and offsets for array_position Avoid per-row subarray allocation from list_array.value(row_index). Instead, downcast the flat values buffer once and iterate using offset ranges directly. Improves from 0.7-0.8X to 0.9X of Spark. 
--- .../src/array_funcs/array_position.rs | 307 +++++++++++------- 1 file changed, 195 insertions(+), 112 deletions(-) diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs index 599d1ce951..843ec6f1f3 100644 --- a/native/spark-expr/src/array_funcs/array_position.rs +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -19,13 +19,14 @@ use arrow::array::{ Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait, }; use arrow::datatypes::{ - DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, TimestampMicrosecondType, + ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, }; use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use num::Float; use std::any::Any; use std::sync::Arc; @@ -36,7 +37,6 @@ fn spark_array_position(args: &[ColumnarValue]) -> Result::None, |acc, arg| match arg { @@ -68,144 +68,227 @@ fn array_position_inner(args: &[ArrayRef]) -> Result } } -/// Find the 1-based position of `search_val` in a typed primitive array. -/// Returns 0 if not found. -macro_rules! find_position_primitive { - ($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{ - let items = $list_items.as_primitive::<$arrow_type>(); - let search = $element.as_primitive::<$arrow_type>(); - let search_val = search.value($row_index); +/// Searches for an element in a list array using the flat values buffer and offsets directly, +/// avoiding per-row subarray allocation. Dispatches to typed fast paths by element data type. +fn generic_array_position( + array: &ArrayRef, + element: &ArrayRef, +) -> Result { + let list_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = list_array.values(); + let offsets = list_array.offsets(); + let elem_type = values.data_type().clone(); + + match &elem_type { + DataType::Boolean => { + position_boolean::(list_array, offsets, values, element) + } + DataType::Int8 => position_primitive::(list_array, offsets, values, element), + DataType::Int16 => position_primitive::(list_array, offsets, values, element), + DataType::Int32 => position_primitive::(list_array, offsets, values, element), + DataType::Int64 => position_primitive::(list_array, offsets, values, element), + DataType::Float32 => { + position_float::(list_array, offsets, values, element) + } + DataType::Float64 => { + position_float::(list_array, offsets, values, element) + } + DataType::Decimal128(_, _) => { + position_primitive::(list_array, offsets, values, element) + } + DataType::Date32 => { + position_primitive::(list_array, offsets, values, element) + } + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { + position_primitive::( + list_array, offsets, values, element, + ) + } + DataType::Utf8 => position_string::(list_array, offsets, values, element), + DataType::LargeUtf8 => position_string::(list_array, offsets, values, element), + // Fallback to ScalarValue for complex types (nested arrays, etc.) + _ => position_fallback::(list_array, offsets, values, element), + } +} + +/// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer. 
+fn position_primitive( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result +where + T::Native: PartialEq, +{ + let values_typed = values.as_primitive::(); + let element_typed = element.as_primitive::(); + let num_rows = list_array.len(); + let mut result = Vec::with_capacity(num_rows); + + for (row_index, w) in offsets.windows(2).enumerate() { + if list_array.is_null(row_index) || element.is_null(row_index) { + result.push(None); + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); let mut pos: i64 = 0; - for i in 0..items.len() { - if !items.is_null(i) && items.value(i) == search_val { - pos = (i + 1) as i64; + for i in start..end { + if !values_typed.is_null(i) && values_typed.value(i) == search_val { + pos = (i - start + 1) as i64; break; } } - pos - }}; + result.push(Some(pos)); + } + + Ok(Arc::new(Int64Array::from(result))) } -/// Float-aware comparison that treats NaN == NaN (matching Spark's ordering.equiv() semantics). -macro_rules! find_position_float { - ($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{ - let items = $list_items.as_primitive::<$arrow_type>(); - let search = $element.as_primitive::<$arrow_type>(); - let search_val = search.value($row_index); +/// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics). +fn position_float( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result +where + T::Native: PartialEq + num::Float, +{ + let values_typed = values.as_primitive::(); + let element_typed = element.as_primitive::(); + let num_rows = list_array.len(); + let mut result = Vec::with_capacity(num_rows); + + for (row_index, w) in offsets.windows(2).enumerate() { + if list_array.is_null(row_index) || element.is_null(row_index) { + result.push(None); + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); let search_is_nan = search_val.is_nan(); let mut pos: i64 = 0; - for i in 0..items.len() { - if !items.is_null(i) { - let item_val = items.value(i); - if (search_is_nan && item_val.is_nan()) || item_val == search_val { - pos = (i + 1) as i64; + for i in start..end { + if !values_typed.is_null(i) { + let v = values_typed.value(i); + if (search_is_nan && v.is_nan()) || v == search_val { + pos = (i - start + 1) as i64; break; } } } - pos - }}; + result.push(Some(pos)); + } + + Ok(Arc::new(Int64Array::from(result))) } -fn find_position_in_row( - list_items: &ArrayRef, +/// Boolean path. 
+fn position_boolean( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, element: &ArrayRef, - row_index: usize, -) -> Result { - let pos = match list_items.data_type() { - DataType::Boolean => { - let items = list_items.as_any().downcast_ref::().unwrap(); - let search = element.as_any().downcast_ref::().unwrap(); - let search_val = search.value(row_index); - let mut pos: i64 = 0; - for i in 0..items.len() { - if !items.is_null(i) && items.value(i) == search_val { - pos = (i + 1) as i64; - break; - } - } - pos - } - DataType::Int8 => find_position_primitive!(list_items, element, row_index, Int8Type), - DataType::Int16 => find_position_primitive!(list_items, element, row_index, Int16Type), - DataType::Int32 => find_position_primitive!(list_items, element, row_index, Int32Type), - DataType::Int64 => find_position_primitive!(list_items, element, row_index, Int64Type), - DataType::Float32 => find_position_float!(list_items, element, row_index, Float32Type), - DataType::Float64 => find_position_float!(list_items, element, row_index, Float64Type), - DataType::Decimal128(_, _) => { - find_position_primitive!(list_items, element, row_index, Decimal128Type) - } - DataType::Date32 => { - find_position_primitive!(list_items, element, row_index, Date32Type) +) -> Result { + let values_typed = values.as_any().downcast_ref::().unwrap(); + let element_typed = element.as_any().downcast_ref::().unwrap(); + let num_rows = list_array.len(); + let mut result = Vec::with_capacity(num_rows); + + for (row_index, w) in offsets.windows(2).enumerate() { + if list_array.is_null(row_index) || element.is_null(row_index) { + result.push(None); + continue; } - DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { - find_position_primitive!(list_items, element, row_index, TimestampMicrosecondType) - } - DataType::Utf8 => { - let items = list_items.as_string::(); - let search = element.as_string::(); - let search_val = search.value(row_index); - let mut pos: i64 = 0; - for i in 0..items.len() { - if !items.is_null(i) && items.value(i) == search_val { - pos = (i + 1) as i64; - break; - } - } - pos - } - DataType::LargeUtf8 => { - let items = list_items.as_string::(); - let search = element.as_string::(); - let search_val = search.value(row_index); - let mut pos: i64 = 0; - for i in 0..items.len() { - if !items.is_null(i) && items.value(i) == search_val { - pos = (i + 1) as i64; - break; - } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); + let mut pos: i64 = 0; + for i in start..end { + if !values_typed.is_null(i) && values_typed.value(i) == search_val { + pos = (i - start + 1) as i64; + break; } - pos } - // Fallback to ScalarValue for complex types (nested arrays, etc.) - _ => { - let element_scalar = ScalarValue::try_from_array(element, row_index)?; - let mut pos: i64 = 0; - for i in 0..list_items.len() { - let item_scalar = ScalarValue::try_from_array(list_items, i)?; - if !item_scalar.is_null() && element_scalar == item_scalar { - pos = (i + 1) as i64; - break; - } + result.push(Some(pos)); + } + + Ok(Arc::new(Int64Array::from(result))) +} + +/// String path: downcast once, iterate using offsets into the flat string buffer. 
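+///
+/// Comparison is plain `&str` equality; as in the other typed paths, null entries inside
+/// the list never match and only the first match is reported.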
+fn position_string( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, + element: &ArrayRef, +) -> Result { + let values_typed = values.as_string::(); + let element_typed = element.as_string::(); + let num_rows = list_array.len(); + let mut result = Vec::with_capacity(num_rows); + + for (row_index, w) in offsets.windows(2).enumerate() { + if list_array.is_null(row_index) || element.is_null(row_index) { + result.push(None); + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_val = element_typed.value(row_index); + let mut pos: i64 = 0; + for i in start..end { + if !values_typed.is_null(i) && values_typed.value(i) == search_val { + pos = (i - start + 1) as i64; + break; } - pos } - }; - Ok(pos) + result.push(Some(pos)); + } + + Ok(Arc::new(Int64Array::from(result))) } -fn generic_array_position( - array: &ArrayRef, +/// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison. +fn position_fallback( + list_array: &GenericListArray, + offsets: &arrow::buffer::OffsetBuffer, + values: &ArrayRef, element: &ArrayRef, ) -> Result { - let list_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - let mut data = Vec::with_capacity(list_array.len()); + let num_rows = list_array.len(); + let mut result = Vec::with_capacity(num_rows); - for row_index in 0..list_array.len() { + for (row_index, w) in offsets.windows(2).enumerate() { if list_array.is_null(row_index) || element.is_null(row_index) { - data.push(None); - } else { - let list_array_row = list_array.value(row_index); - let position = find_position_in_row(&list_array_row, element, row_index)?; - data.push(Some(position)); + result.push(None); + continue; + } + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let search_scalar = ScalarValue::try_from_array(element, row_index)?; + let mut pos: i64 = 0; + for i in start..end { + if !values.is_null(i) { + let item_scalar = ScalarValue::try_from_array(values, i)?; + if search_scalar == item_scalar { + pos = (i - start + 1) as i64; + break; + } + } } + result.push(Some(pos)); } - Ok(Arc::new(Int64Array::from(data))) + Ok(Arc::new(Int64Array::from(result))) } #[derive(Debug, Hash, Eq, PartialEq)] From 79b348f459c6e3de517bd17ca6a593425be2158c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 16 Mar 2026 13:31:44 -0600 Subject: [PATCH 12/18] perf: use native_datafusion scan in array_position benchmark Switches benchmark to use SCAN_NATIVE_DATAFUSION for the Comet cases, avoiding JVM parquet reader overhead. Results now show Comet is 1.1-1.2X faster than Spark. --- .../benchmark/CometArrayExpressionBenchmark.scala | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala index b197209011..28396ae14b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.benchmark +import org.apache.comet.CometConf + // spotless:off /** * Benchmark to measure performance of Comet array expressions. 
To run this benchmark: @@ -52,10 +54,14 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase { | cast((value + 5) % 100 as int) as search_val |FROM $tbl""".stripMargin)) + val nativeScanConfig = + Map(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + runExpressionBenchmark( "array_position - int array", values, - "SELECT array_position(int_arr, search_val) FROM parquetV1Table") + "SELECT array_position(int_arr, search_val) FROM parquetV1Table", + nativeScanConfig) } } @@ -80,10 +86,14 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase { | cast((value + 5) % 100 as string) as search_val |FROM $tbl""".stripMargin)) + val nativeScanConfig = + Map(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + runExpressionBenchmark( "array_position - string array", values, - "SELECT array_position(str_arr, search_val) FROM parquetV1Table") + "SELECT array_position(str_arr, search_val) FROM parquetV1Table", + nativeScanConfig) } } } From 41b54959a025e3cd70126d52e36a676d0f0dd6d2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 06:08:20 -0600 Subject: [PATCH 13/18] format --- .../spark-expr/src/array_funcs/array_position.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs index 843ec6f1f3..640c045841 100644 --- a/native/spark-expr/src/array_funcs/array_position.rs +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -84,19 +84,13 @@ fn generic_array_position( let elem_type = values.data_type().clone(); match &elem_type { - DataType::Boolean => { - position_boolean::(list_array, offsets, values, element) - } + DataType::Boolean => position_boolean::(list_array, offsets, values, element), DataType::Int8 => position_primitive::(list_array, offsets, values, element), DataType::Int16 => position_primitive::(list_array, offsets, values, element), DataType::Int32 => position_primitive::(list_array, offsets, values, element), DataType::Int64 => position_primitive::(list_array, offsets, values, element), - DataType::Float32 => { - position_float::(list_array, offsets, values, element) - } - DataType::Float64 => { - position_float::(list_array, offsets, values, element) - } + DataType::Float32 => position_float::(list_array, offsets, values, element), + DataType::Float64 => position_float::(list_array, offsets, values, element), DataType::Decimal128(_, _) => { position_primitive::(list_array, offsets, values, element) } @@ -104,9 +98,7 @@ fn generic_array_position( position_primitive::(list_array, offsets, values, element) } DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { - position_primitive::( - list_array, offsets, values, element, - ) + position_primitive::(list_array, offsets, values, element) } DataType::Utf8 => position_string::(list_array, offsets, values, element), DataType::LargeUtf8 => position_string::(list_array, offsets, values, element), From 736ca63a3e16855e399bbf51d55b065283efbba8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 13 Apr 2026 16:16:53 -0600 Subject: [PATCH 14/18] chore: apply cargo fmt --- native/spark-expr/src/comet_scalar_funcs.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 3c0292b9d7..42354061b9 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs 
@@ -24,8 +24,7 @@ use crate::{ spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayPositionFunc, SparkContains, - SparkDateDiff, - SparkDateTrunc, SparkMakeDate, SparkSizeFunc, + SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; From 8986da2df58cd927699f018cedc5c38440f33b93 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 14 Apr 2026 09:34:40 -0600 Subject: [PATCH 15/18] fix: address review feedback for array_position - Compute combined null buffer upfront via NullBuffer::union and use Vec with Int64Array::new() instead of Vec>, avoiding per-row null tracking overhead in all typed paths - Use TypeSignature::Any(2) instead of variadic_any for precise arity - Replace .unwrap() on downcast with .ok_or_else() for safer error handling - Add nested array test cases to exercise position_fallback code path --- .../src/array_funcs/array_position.rs | 83 ++++++++++--------- .../expressions/array/array_position.sql | 7 ++ 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs index 640c045841..dc1e527c00 100644 --- a/native/spark-expr/src/array_funcs/array_position.rs +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -18,13 +18,14 @@ use arrow::array::{ Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait, }; +use arrow::buffer::{NullBuffer, ScalarBuffer}; use arrow::datatypes::{ ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimestampMicrosecondType, }; use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use num::Float; use std::any::Any; @@ -77,7 +78,7 @@ fn generic_array_position( let list_array = array .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| DataFusionError::Internal("expected list array".into()))?; let values = list_array.values(); let offsets = list_array.offsets(); @@ -107,6 +108,16 @@ fn generic_array_position( } } +/// Compute the combined null buffer from list array and element nulls. +fn combined_nulls(list_array_nulls: Option<&NullBuffer>, element_nulls: Option<&NullBuffer>) -> Option { + match (list_array_nulls, element_nulls) { + (Some(a), Some(b)) => NullBuffer::union(Some(a), Some(b)), + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (None, None) => None, + } +} + /// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer. 
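+/// Row-level nulls (null list row or null search element) come from a combined `NullBuffer`
+/// computed once per batch; the scan only fills positions for valid rows, and the placeholder
+/// 0 left in invalid slots is masked by the null buffer attached to the output `Int64Array`.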
fn position_primitive( list_array: &GenericListArray, @@ -120,27 +131,25 @@ where let values_typed = values.as_primitive::(); let element_typed = element.as_primitive::(); let num_rows = list_array.len(); - let mut result = Vec::with_capacity(num_rows); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; for (row_index, w) in offsets.windows(2).enumerate() { - if list_array.is_null(row_index) || element.is_null(row_index) { - result.push(None); + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { continue; } let start = w[0].as_usize(); let end = w[1].as_usize(); let search_val = element_typed.value(row_index); - let mut pos: i64 = 0; for i in start..end { if !values_typed.is_null(i) && values_typed.value(i) == search_val { - pos = (i - start + 1) as i64; + result[row_index] = (i - start + 1) as i64; break; } } - result.push(Some(pos)); } - Ok(Arc::new(Int64Array::from(result))) + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) } /// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics). @@ -156,31 +165,29 @@ where let values_typed = values.as_primitive::(); let element_typed = element.as_primitive::(); let num_rows = list_array.len(); - let mut result = Vec::with_capacity(num_rows); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; for (row_index, w) in offsets.windows(2).enumerate() { - if list_array.is_null(row_index) || element.is_null(row_index) { - result.push(None); + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { continue; } let start = w[0].as_usize(); let end = w[1].as_usize(); let search_val = element_typed.value(row_index); let search_is_nan = search_val.is_nan(); - let mut pos: i64 = 0; for i in start..end { if !values_typed.is_null(i) { let v = values_typed.value(i); if (search_is_nan && v.is_nan()) || v == search_val { - pos = (i - start + 1) as i64; + result[row_index] = (i - start + 1) as i64; break; } } } - result.push(Some(pos)); } - Ok(Arc::new(Int64Array::from(result))) + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) } /// Boolean path. 
@@ -190,30 +197,30 @@ fn position_boolean( values: &ArrayRef, element: &ArrayRef, ) -> Result { - let values_typed = values.as_any().downcast_ref::().unwrap(); - let element_typed = element.as_any().downcast_ref::().unwrap(); + let values_typed = values.as_any().downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?; + let element_typed = element.as_any().downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?; let num_rows = list_array.len(); - let mut result = Vec::with_capacity(num_rows); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; for (row_index, w) in offsets.windows(2).enumerate() { - if list_array.is_null(row_index) || element.is_null(row_index) { - result.push(None); + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { continue; } let start = w[0].as_usize(); let end = w[1].as_usize(); let search_val = element_typed.value(row_index); - let mut pos: i64 = 0; for i in start..end { if !values_typed.is_null(i) && values_typed.value(i) == search_val { - pos = (i - start + 1) as i64; + result[row_index] = (i - start + 1) as i64; break; } } - result.push(Some(pos)); } - Ok(Arc::new(Int64Array::from(result))) + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) } /// String path: downcast once, iterate using offsets into the flat string buffer. @@ -226,27 +233,25 @@ fn position_string( let values_typed = values.as_string::(); let element_typed = element.as_string::(); let num_rows = list_array.len(); - let mut result = Vec::with_capacity(num_rows); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; for (row_index, w) in offsets.windows(2).enumerate() { - if list_array.is_null(row_index) || element.is_null(row_index) { - result.push(None); + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { continue; } let start = w[0].as_usize(); let end = w[1].as_usize(); let search_val = element_typed.value(row_index); - let mut pos: i64 = 0; for i in start..end { if !values_typed.is_null(i) && values_typed.value(i) == search_val { - pos = (i - start + 1) as i64; + result[row_index] = (i - start + 1) as i64; break; } } - result.push(Some(pos)); } - Ok(Arc::new(Int64Array::from(result))) + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) } /// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison. 
@@ -257,30 +262,28 @@ fn position_fallback( element: &ArrayRef, ) -> Result { let num_rows = list_array.len(); - let mut result = Vec::with_capacity(num_rows); + let nulls = combined_nulls(list_array.nulls(), element.nulls()); + let mut result = vec![0i64; num_rows]; for (row_index, w) in offsets.windows(2).enumerate() { - if list_array.is_null(row_index) || element.is_null(row_index) { - result.push(None); + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { continue; } let start = w[0].as_usize(); let end = w[1].as_usize(); let search_scalar = ScalarValue::try_from_array(element, row_index)?; - let mut pos: i64 = 0; for i in start..end { if !values.is_null(i) { let item_scalar = ScalarValue::try_from_array(values, i)?; if search_scalar == item_scalar { - pos = (i - start + 1) as i64; + result[row_index] = (i - start + 1) as i64; break; } } } - result.push(Some(pos)); } - Ok(Arc::new(Int64Array::from(result))) + Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls))) } #[derive(Debug, Hash, Eq, PartialEq)] @@ -297,7 +300,7 @@ impl Default for SparkArrayPositionFunc { impl SparkArrayPositionFunc { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable), } } } diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql index af04c0cd9d..373267fdb5 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -218,6 +218,13 @@ INSERT INTO test_ap_date VALUES query SELECT array_position(arr, val) FROM test_ap_date +-- nested array (exercises position_fallback code path) +query spark_answer_only +SELECT array_position(array(array(1, 2), array(3, 4)), array(1, 2)) + +query spark_answer_only +SELECT array_position(array(array(1, 2), array(3, 4)), array(5, 6)) + -- timestamp arrays statement CREATE TABLE test_ap_ts(arr array, val timestamp) USING parquet From 235d0535b689953265e6641bb08cfbb90448250a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 18 Apr 2026 10:46:34 -0600 Subject: [PATCH 16/18] style: cargo fmt --- native/spark-expr/src/array_funcs/array_position.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/array_funcs/array_position.rs b/native/spark-expr/src/array_funcs/array_position.rs index dc1e527c00..dbcd8615df 100644 --- a/native/spark-expr/src/array_funcs/array_position.rs +++ b/native/spark-expr/src/array_funcs/array_position.rs @@ -109,7 +109,10 @@ fn generic_array_position( } /// Compute the combined null buffer from list array and element nulls. 
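+/// A row of the result is null when either the list row or the search element is null;
+/// `NullBuffer::union` combines the two validity bitmaps so a slot stays valid only when
+/// both inputs are valid.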
-fn combined_nulls(list_array_nulls: Option<&NullBuffer>, element_nulls: Option<&NullBuffer>) -> Option { +fn combined_nulls( + list_array_nulls: Option<&NullBuffer>, + element_nulls: Option<&NullBuffer>, +) -> Option { match (list_array_nulls, element_nulls) { (Some(a), Some(b)) => NullBuffer::union(Some(a), Some(b)), (Some(a), None) => Some(a.clone()), @@ -197,9 +200,13 @@ fn position_boolean( values: &ArrayRef, element: &ArrayRef, ) -> Result { - let values_typed = values.as_any().downcast_ref::() + let values_typed = values + .as_any() + .downcast_ref::() .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?; - let element_typed = element.as_any().downcast_ref::() + let element_typed = element + .as_any() + .downcast_ref::() .ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?; let num_rows = list_array.len(); let nulls = combined_nulls(list_array.nulls(), element.nulls()); From 7558c66e901b49ff5d988ffcb9cf8f4d5eae173d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 20 Apr 2026 19:52:40 -0600 Subject: [PATCH 17/18] test: add timestamp_ntz coverage for array_position --- .../sql-tests/expressions/array/array_position.sql | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql index 373267fdb5..7f75e28fb9 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -238,3 +238,17 @@ INSERT INTO test_ap_ts VALUES query SELECT array_position(arr, val) FROM test_ap_ts + +-- timestamp_ntz arrays +statement +CREATE TABLE test_ap_ts_ntz(arr array, val timestamp_ntz) USING parquet + +statement +INSERT INTO test_ap_ts_ntz VALUES + (array(timestamp_ntz '2024-01-01 00:00:00', timestamp_ntz '2024-06-15 12:30:00'), timestamp_ntz '2024-06-15 12:30:00'), + (array(timestamp_ntz '2000-01-01 00:00:00'), timestamp_ntz '1999-12-31 23:59:59'), + (NULL, timestamp_ntz '2024-01-01 00:00:00'), + (array(timestamp_ntz '2024-01-01 00:00:00'), NULL) + +query +SELECT array_position(arr, val) FROM test_ap_ts_ntz From 60bceb5a9e0b998c646ee438a153c8d849df26a5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 20 Apr 2026 19:56:23 -0600 Subject: [PATCH 18/18] test: add column-based nested array coverage for array_position --- .../expressions/array/array_position.sql | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql index 7f75e28fb9..132158ab6d 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_position.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_position.sql @@ -225,6 +225,35 @@ SELECT array_position(array(array(1, 2), array(3, 4)), array(1, 2)) query spark_answer_only SELECT array_position(array(array(1, 2), array(3, 4)), array(5, 6)) +-- nested int array column (exercises position_fallback natively) +statement +CREATE TABLE test_ap_nested_int(arr array>, val array) USING parquet + +statement +INSERT INTO test_ap_nested_int VALUES + (array(array(1, 2), array(3, 4), array(5, 6)), array(3, 4)), + (array(array(1, 2), array(3, 4)), array(5, 6)), + (array(array(1, 2), array(1, 2, 3)), array(1, 2)), + (NULL, array(1, 2)), + (array(array(1, 2)), NULL) + +query +SELECT array_position(arr, val) FROM 
test_ap_nested_int
+
+-- nested string array column (exercises position_fallback natively)
+statement
+CREATE TABLE test_ap_nested_str(arr array<array<string>>, val array<string>) USING parquet
+
+statement
+INSERT INTO test_ap_nested_str VALUES
+  (array(array('a', 'b'), array('c', 'd')), array('c', 'd')),
+  (array(array('a', 'b'), array('c', 'd')), array('x', 'y')),
+  (NULL, array('a')),
+  (array(array('a')), NULL)
+
+query
+SELECT array_position(arr, val) FROM test_ap_nested_str
+
 -- timestamp arrays
 statement
 CREATE TABLE test_ap_ts(arr array<timestamp>, val timestamp) USING parquet
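For readers unfamiliar with the offset-window scan used by the typed fast paths above, the following standalone sketch (illustrative only, not part of the patch; the data and variable names are made up) applies the same technique to a plain Int32 ListArray with arrow-rs:

    use std::sync::Arc;

    use arrow::array::{Array, Int32Array, ListArray};
    use arrow::buffer::{NullBuffer, OffsetBuffer};
    use arrow::datatypes::{DataType, Field};

    fn main() {
        // Three rows: [1, 2, 3], [4, 5], NULL
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let offsets = OffsetBuffer::new(vec![0i32, 3, 5, 5].into());
        let field = Arc::new(Field::new("item", DataType::Int32, true));
        let nulls = Some(NullBuffer::from(vec![true, true, false]));
        let list = ListArray::new(field, offsets, Arc::new(values), nulls);

        // Downcast the flat values buffer once, outside the row loop.
        let flat = list.values().as_any().downcast_ref::<Int32Array>().unwrap();
        let search = [2, 9, 1]; // one search value per row

        for (row, w) in list.offsets().windows(2).enumerate() {
            if list.is_null(row) {
                println!("row {row}: NULL"); // Spark: null array -> null result
                continue;
            }
            let (start, end) = (w[0] as usize, w[1] as usize);
            let pos = (start..end)
                .position(|i| !flat.is_null(i) && flat.value(i) == search[row])
                .map(|p| p as i64 + 1) // 1-based position within the row
                .unwrap_or(0); // Spark returns 0 when the element is not found
            println!("row {row}: {pos}");
        }
    }

Running this prints positions 2, 0, and NULL for the three rows, matching the Spark semantics exercised by the SQL tests above.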