From c514fdab2e3867d57118a20d47570c8688234b21 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 17 Jan 2026 01:40:16 +0530 Subject: [PATCH 01/11] feat: implement map_sort native function --- native/spark-expr/src/comet_scalar_funcs.rs | 10 +- native/spark-expr/src/lib.rs | 2 + native/spark-expr/src/map_funcs/map_sort.rs | 185 ++++++++++++++++++++ native/spark-expr/src/map_funcs/mod.rs | 20 +++ 4 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 native/spark-expr/src/map_funcs/map_sort.rs create mode 100644 native/spark-expr/src/map_funcs/mod.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 1eaf0b2a97..4f11957409 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,9 +22,9 @@ use crate::math_funcs::log::spark_log; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, - SparkSizeFunc, + spark_lpad, spark_make_decimal, spark_map_sort, spark_read_side_padding, spark_round, + spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkContains, SparkDateDiff, + SparkDateTrunc, SparkMakeDate, SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -186,6 +186,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) } + "map_sort" => { + let func = Arc::new(spark_map_sort); + make_comet_scalar_udf!("map_sort", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 342ef73619..5aabe5c62e 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -45,6 +45,7 @@ mod array_funcs; mod comet_scalar_funcs; pub mod hash_funcs; +mod map_funcs; mod string_funcs; mod datetime_funcs; @@ -63,6 +64,7 @@ mod nondetermenistic_funcs; pub use array_funcs::*; pub use conditional_funcs::*; pub use conversion_funcs::*; +pub use map_funcs::*; pub use nondetermenistic_funcs::*; pub use comet_scalar_funcs::{ diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs new file mode 100644 index 0000000000..4841cf50a6 --- /dev/null +++ b/native/spark-expr/src/map_funcs/map_sort.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, MapArray}; +use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions}; +use arrow::datatypes::{DataType, Field}; +use datafusion::common::{exec_err, DataFusionError, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use std::sync::Arc; + +pub fn spark_map_sort(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!("map_sort function takes exactly one argument"); + } + + match &args[0] { + ColumnarValue::Array(array) => { + let result = spark_map_sort_array(array)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar) => { + let result = spark_map_sort_scalar(scalar)?; + Ok(ColumnarValue::Scalar(result)) + } + } +} + +fn spark_map_sort_array(array: &ArrayRef) -> Result { + let map_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected MapArray".to_string()))?; + + let entries = map_array.entries(); + let struct_array = entries + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected StructArray for entries".to_string()))?; + + if struct_array.num_columns() != 2 { + return exec_err!("Map entries must have exactly 2 columns (keys and values)"); + } + + let keys = struct_array.column(0); + let values = struct_array.column(1); + let offsets = map_array.offsets(); + + let mut sorted_keys_arrays = Vec::new(); + let mut sorted_values_arrays = Vec::new(); + let mut new_offsets = Vec::with_capacity(map_array.len() + 1); + new_offsets.push(0i32); + + for row_idx in 0..map_array.len() { + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + let len = end - start; + + if len == 0 { + new_offsets.push(new_offsets[row_idx]); + continue; + } + + let row_keys = keys.slice(start, len); + let row_values = values.slice(start, len); + + if len == 1 { + sorted_keys_arrays.push(row_keys); + sorted_values_arrays.push(row_values); + new_offsets.push(new_offsets[row_idx] + len as i32); + continue; + } + + let sort_columns = vec![SortColumn { + values: Arc::clone(&row_keys), + options: Some(SortOptions { + descending: false, + nulls_first: false, + }), + }]; + + let indices = lexsort_to_indices(&sort_columns, None)?; + let sorted_keys = take(&row_keys, &indices, None)?; + let sorted_values = take(&row_values, &indices, None)?; + + sorted_keys_arrays.push(sorted_keys); + sorted_values_arrays.push(sorted_values); + new_offsets.push(new_offsets[row_idx] + len as i32); + } + + if sorted_keys_arrays.is_empty() { + let key_field = Arc::new(Field::new( + "key", + keys.data_type().clone(), + keys.is_nullable(), + )); + let value_field = Arc::new(Field::new( + "value", + values.data_type().clone(), + values.is_nullable(), + )); + let entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(vec![Arc::clone(&key_field), Arc::clone(&value_field)].into()), + false, + )); + + let empty_keys = arrow::array::new_empty_array(keys.data_type()); + let empty_values = arrow::array::new_empty_array(values.data_type()); + let empty_entries = arrow::array::StructArray::new( + vec![key_field, value_field].into(), + vec![empty_keys, empty_values], + None, + ); + + return Ok(Arc::new(MapArray::new( + entries_field, + arrow::buffer::OffsetBuffer::new(vec![0i32; map_array.len() + 1].into()), + empty_entries, + map_array.nulls().cloned(), + false, + ))); + } + + let sorted_keys_refs: Vec<&dyn Array> = sorted_keys_arrays.iter().map(|a| a.as_ref()).collect(); + let sorted_values_refs: Vec<&dyn Array> = + sorted_values_arrays.iter().map(|a| a.as_ref()).collect(); + + let concatenated_keys = arrow::compute::concat(&sorted_keys_refs)?; + let concatenated_values = arrow::compute::concat(&sorted_values_refs)?; + + let key_field = Arc::new(Field::new( + "key", + keys.data_type().clone(), + keys.is_nullable(), + )); + let value_field = Arc::new(Field::new( + "value", + values.data_type().clone(), + values.is_nullable(), + )); + + let sorted_entries = arrow::array::StructArray::new( + vec![Arc::clone(&key_field), Arc::clone(&value_field)].into(), + vec![concatenated_keys, concatenated_values], + None, + ); + + let entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(vec![key_field, value_field].into()), + false, + )); + + Ok(Arc::new(MapArray::new( + entries_field, + arrow::buffer::OffsetBuffer::new(new_offsets.into()), + sorted_entries, + map_array.nulls().cloned(), + false, + ))) +} + +fn spark_map_sort_scalar(scalar: &ScalarValue) -> Result { + match scalar { + ScalarValue::Null => Ok(ScalarValue::Null), + _ => exec_err!( + "map_sort scalar function only supports map types, got: {:?}", + scalar + ), + } +} diff --git a/native/spark-expr/src/map_funcs/mod.rs b/native/spark-expr/src/map_funcs/mod.rs new file mode 100644 index 0000000000..a1bfd34923 --- /dev/null +++ b/native/spark-expr/src/map_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod map_sort; + +pub use map_sort::spark_map_sort; From ada3501c53f29e66652abc98ee69bc5e6ce5e992 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 17 Jan 2026 01:40:20 +0530 Subject: [PATCH 02/11] feat: add CometMapSort serializer --- .../apache/comet/serde/QueryPlanSerde.scala | 3 ++- .../scala/org/apache/comet/serde/maps.scala | 22 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 59fb0f9819..407f82589f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -131,7 +131,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapValues] -> CometMapValues, classOf[MapFromArrays] -> CometMapFromArrays, classOf[MapContainsKey] -> CometMapContainsKey, - classOf[MapFromEntries] -> CometMapFromEntries) + classOf[MapFromEntries] -> CometMapFromEntries, + classOf[MapSort] -> CometMapSort) private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[CreateNamedStruct] -> CometCreateNamedStruct, diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index ceafc157c4..acb135b369 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -20,9 +20,11 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects.RowOrdering import org.apache.spark.sql.types._ -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} object CometMapKeys extends CometExpressionSerde[MapKeys] { @@ -156,3 +158,21 @@ object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from Compatible(None) } } + +object CometMapSort extends CometExpressionSerde[MapSort] { + + override def convert( + expr: MapSort, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val keyType = expr.base.dataType.asInstanceOf[MapType].keyType + if (!RowOrdering.isOrderable(keyType)) { + withInfo(expr, s"map_sort requires orderable key type, got: $keyType") + return None + } + + val childExpr = exprToProtoInternal(expr.base, inputs, binding) + val mapSortScalarExpr = scalarFunctionExprToProto("map_sort", childExpr) + optExprWithInfo(mapSortScalarExpr, expr, expr.children: _*) + } +} From c4b4c44268114c41438a983622f1e2221182e769 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 17 Jan 2026 01:40:26 +0530 Subject: [PATCH 03/11] test: add comprehensive map_sort tests --- .../comet/CometMapExpressionSuite.scala | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index 03db26e566..afb1e96bbb 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -236,4 +236,167 @@ class CometMapExpressionSuite extends CometTestBase { } } + + + test("map_sort with integer keys") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select(map(lit(3), lit("c"), lit(1), lit("a"), lit(2), lit("b")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with string keys") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select(map(lit("z"), lit(1), lit("a"), lit(2), lit("m"), lit(3)).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with double keys") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select(map(lit(3.5), lit("c"), lit(1.2), lit("a"), lit(2.8), lit("b")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with null and empty maps") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select( + when(col("id") === 0, lit(null)) + .when(col("id") === 1, map()) + .when(col("id") === 2, map(lit(1), lit("a"))) + .otherwise(map(lit(3), lit("c"), lit(2), lit("b"))) + .alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with struct keys") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(3) + .select( + map( + struct(lit(2), lit("b")), + lit("second"), + struct(lit(1), lit("a")), + lit("first"), + struct(lit(3), lit("c")), + lit("third")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with array keys") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(3) + .select( + map( + array(lit(2), lit(3)), + lit("array2"), + array(lit(1), lit(2)), + lit("array1"), + array(lit(3), lit(4)), + lit("array3")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with complex values") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(3) + .select( + map( + lit(3), + struct(lit("c"), array(lit(30), lit(31))), + lit(1), + struct(lit("a"), array(lit(10), lit(11))), + lit(2), + struct(lit("b"), array(lit(20), lit(21)))).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort fallback for non-orderable keys") { + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(3) + .select( + map( + map(lit(1), lit("inner1")), + lit("outer1"), + map(lit(2), lit("inner2")), + lit("outer2")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndFallbackReason( + sql("SELECT map_sort(m) FROM t1"), + "map_sort requires orderable key type") + } + } + } + } From 07d1013679d9a43fe270792c77d3c2184092afcb Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Wed, 21 Jan 2026 01:22:24 +0530 Subject: [PATCH 04/11] refactor: remove MapSort from shared code MapSort only exists in Spark 4.0+, not in 3.4/3.5 --- .../apache/comet/serde/QueryPlanSerde.scala | 3 +-- .../scala/org/apache/comet/serde/maps.scala | 19 ------------------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 407f82589f..59fb0f9819 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -131,8 +131,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapValues] -> CometMapValues, classOf[MapFromArrays] -> CometMapFromArrays, classOf[MapContainsKey] -> CometMapContainsKey, - classOf[MapFromEntries] -> CometMapFromEntries, - classOf[MapSort] -> CometMapSort) + classOf[MapFromEntries] -> CometMapFromEntries) private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[CreateNamedStruct] -> CometCreateNamedStruct, diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index acb135b369..b038fcf954 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -20,7 +20,6 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.RowOrdering import org.apache.spark.sql.types._ import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -158,21 +157,3 @@ object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from Compatible(None) } } - -object CometMapSort extends CometExpressionSerde[MapSort] { - - override def convert( - expr: MapSort, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val keyType = expr.base.dataType.asInstanceOf[MapType].keyType - if (!RowOrdering.isOrderable(keyType)) { - withInfo(expr, s"map_sort requires orderable key type, got: $keyType") - return None - } - - val childExpr = exprToProtoInternal(expr.base, inputs, binding) - val mapSortScalarExpr = scalarFunctionExprToProto("map_sort", childExpr) - optExprWithInfo(mapSortScalarExpr, expr, expr.children: _*) - } -} From e050c4f6e26cc3e57baea59a4eb56ba4bb4d5f83 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Wed, 21 Jan 2026 01:22:31 +0530 Subject: [PATCH 05/11] feat: add MapSort support for Spark 4.0 via shim Adds MapSort serialization to Spark 4.0 version shim --- .../org/apache/comet/shims/CometExprShim.scala | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 2c5cebd166..86e0a882b9 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -20,16 +20,16 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.objects.{RowOrdering, StaticInvoke} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, MapType, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. @@ -103,6 +103,17 @@ trait CometExprShim extends CommonStringExprs { None } + case expr: MapSort => + val keyType = expr.base.dataType.asInstanceOf[MapType].keyType + if (!RowOrdering.isOrderable(keyType)) { + withInfo(expr, s"map_sort requires orderable key type, got: $keyType") + return None + } + + val childExpr = exprToProtoInternal(expr.base, inputs, binding) + val mapSortScalarExpr = scalarFunctionExprToProto("map_sort", childExpr) + optExprWithInfo(mapSortScalarExpr, expr, expr.children: _*) + case wb: WidthBucket => withInfo( wb, From 25539cecdac53e33f365c5e83363a99c467b31f8 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 6 Apr 2026 16:46:58 +0530 Subject: [PATCH 06/11] fix: preserve field metadata and handle nulls in map_sort native impl --- native/spark-expr/src/map_funcs/map_sort.rs | 145 +++++++------------- 1 file changed, 49 insertions(+), 96 deletions(-) diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs index 4841cf50a6..115bb0f6f1 100644 --- a/native/spark-expr/src/map_funcs/map_sort.rs +++ b/native/spark-expr/src/map_funcs/map_sort.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, MapArray}; +use arrow::array::{Array, ArrayRef, MapArray, StructArray}; use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use datafusion::common::{exec_err, DataFusionError, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; +/// Sorts map entries by key in ascending order, matching Spark's MapSort semantics. pub fn spark_map_sort(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return exec_err!("map_sort function takes exactly one argument"); @@ -29,38 +30,37 @@ pub fn spark_map_sort(args: &[ColumnarValue]) -> Result { - let result = spark_map_sort_array(array)?; + let result = sort_map_array(array)?; Ok(ColumnarValue::Array(result)) } - ColumnarValue::Scalar(scalar) => { - let result = spark_map_sort_scalar(scalar)?; - Ok(ColumnarValue::Scalar(result)) + ColumnarValue::Scalar(ScalarValue::Null) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + ColumnarValue::Scalar(other) => { + exec_err!("map_sort expects a map type, got: {:?}", other) } } } -fn spark_map_sort_array(array: &ArrayRef) -> Result { +fn sort_map_array(array: &ArrayRef) -> Result { let map_array = array .as_any() .downcast_ref::() .ok_or_else(|| DataFusionError::Internal("Expected MapArray".to_string()))?; - let entries = map_array.entries(); - let struct_array = entries - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::Internal("Expected StructArray for entries".to_string()))?; - - if struct_array.num_columns() != 2 { - return exec_err!("Map entries must have exactly 2 columns (keys and values)"); - } + // Extract the entries field from the input data type to preserve field names and metadata + let (entries_field, ordered) = match map_array.data_type() { + DataType::Map(field, ordered) => (Arc::clone(field), *ordered), + other => { + return exec_err!("Expected Map data type, got: {:?}", other); + } + }; - let keys = struct_array.column(0); - let values = struct_array.column(1); + let entries = map_array.entries(); + let keys = entries.column(0); + let values = entries.column(1); let offsets = map_array.offsets(); - let mut sorted_keys_arrays = Vec::new(); - let mut sorted_values_arrays = Vec::new(); + let mut sorted_keys_arrays: Vec = Vec::new(); + let mut sorted_values_arrays: Vec = Vec::new(); let mut new_offsets = Vec::with_capacity(map_array.len() + 1); new_offsets.push(0i32); @@ -69,8 +69,8 @@ fn spark_map_sort_array(array: &ArrayRef) -> Result { let end = offsets[row_idx + 1] as usize; let len = end - start; - if len == 0 { - new_offsets.push(new_offsets[row_idx]); + if len == 0 || map_array.is_null(row_idx) { + new_offsets.push(*new_offsets.last().unwrap()); continue; } @@ -80,7 +80,7 @@ fn spark_map_sort_array(array: &ArrayRef) -> Result { if len == 1 { sorted_keys_arrays.push(row_keys); sorted_values_arrays.push(row_values); - new_offsets.push(new_offsets[row_idx] + len as i32); + new_offsets.push(*new_offsets.last().unwrap() + len as i32); continue; } @@ -98,88 +98,41 @@ fn spark_map_sort_array(array: &ArrayRef) -> Result { sorted_keys_arrays.push(sorted_keys); sorted_values_arrays.push(sorted_values); - new_offsets.push(new_offsets[row_idx] + len as i32); + new_offsets.push(*new_offsets.last().unwrap() + len as i32); } - if sorted_keys_arrays.is_empty() { - let key_field = Arc::new(Field::new( - "key", - keys.data_type().clone(), - keys.is_nullable(), - )); - let value_field = Arc::new(Field::new( - "value", - values.data_type().clone(), - values.is_nullable(), - )); - let entries_field = Arc::new(Field::new( - "entries", - DataType::Struct(vec![Arc::clone(&key_field), Arc::clone(&value_field)].into()), - false, - )); - - let empty_keys = arrow::array::new_empty_array(keys.data_type()); - let empty_values = arrow::array::new_empty_array(values.data_type()); - let empty_entries = arrow::array::StructArray::new( - vec![key_field, value_field].into(), - vec![empty_keys, empty_values], - None, - ); - - return Ok(Arc::new(MapArray::new( - entries_field, - arrow::buffer::OffsetBuffer::new(vec![0i32; map_array.len() + 1].into()), - empty_entries, - map_array.nulls().cloned(), - false, - ))); - } - - let sorted_keys_refs: Vec<&dyn Array> = sorted_keys_arrays.iter().map(|a| a.as_ref()).collect(); - let sorted_values_refs: Vec<&dyn Array> = - sorted_values_arrays.iter().map(|a| a.as_ref()).collect(); - - let concatenated_keys = arrow::compute::concat(&sorted_keys_refs)?; - let concatenated_values = arrow::compute::concat(&sorted_values_refs)?; - - let key_field = Arc::new(Field::new( - "key", - keys.data_type().clone(), - keys.is_nullable(), - )); - let value_field = Arc::new(Field::new( - "value", - values.data_type().clone(), - values.is_nullable(), - )); - - let sorted_entries = arrow::array::StructArray::new( - vec![Arc::clone(&key_field), Arc::clone(&value_field)].into(), - vec![concatenated_keys, concatenated_values], + // Reconstruct using the original entries field to preserve field names and metadata + let (key_field, value_field) = map_array.entries_fields(); + let key_field = Arc::new(key_field.clone()); + let value_field = Arc::new(value_field.clone()); + + let (sorted_keys_col, sorted_values_col) = if sorted_keys_arrays.is_empty() { + ( + arrow::array::new_empty_array(keys.data_type()), + arrow::array::new_empty_array(values.data_type()), + ) + } else { + let keys_refs: Vec<&dyn Array> = + sorted_keys_arrays.iter().map(|a| a.as_ref()).collect(); + let values_refs: Vec<&dyn Array> = + sorted_values_arrays.iter().map(|a| a.as_ref()).collect(); + ( + arrow::compute::concat(&keys_refs)?, + arrow::compute::concat(&values_refs)?, + ) + }; + + let sorted_entries = StructArray::new( + vec![key_field, value_field].into(), + vec![sorted_keys_col, sorted_values_col], None, ); - let entries_field = Arc::new(Field::new( - "entries", - DataType::Struct(vec![key_field, value_field].into()), - false, - )); - Ok(Arc::new(MapArray::new( entries_field, arrow::buffer::OffsetBuffer::new(new_offsets.into()), sorted_entries, map_array.nulls().cloned(), - false, + ordered, ))) } - -fn spark_map_sort_scalar(scalar: &ScalarValue) -> Result { - match scalar { - ScalarValue::Null => Ok(ScalarValue::Null), - _ => exec_err!( - "map_sort scalar function only supports map types, got: {:?}", - scalar - ), - } -} From 5aca5ac7af2cec9178a5b9253d61c6fe6e844f9b Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 6 Apr 2026 16:47:03 +0530 Subject: [PATCH 07/11] fix: add Spark 4.0 version guards to map_sort tests and remove unused import --- spark/src/main/scala/org/apache/comet/serde/maps.scala | 2 +- .../scala/org/apache/comet/CometMapExpressionSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index b038fcf954..767f6c2ba1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} object CometMapKeys extends CometExpressionSerde[MapKeys] { diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index afb1e96bbb..fc148f181e 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.BinaryType +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus import org.apache.comet.serde.CometMapFromEntries import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions} @@ -239,6 +240,7 @@ class CometMapExpressionSuite extends CometTestBase { test("map_sort with integer keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -255,6 +257,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort with string keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -271,6 +274,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort with double keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -287,6 +291,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort with null and empty maps") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -308,6 +313,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort with struct keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -331,6 +337,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort with array keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -354,6 +361,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort with complex values") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") @@ -377,6 +385,7 @@ class CometMapExpressionSuite extends CometTestBase { } test("map_sort fallback for non-orderable keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => withTempView("t1") { val path = new Path(dir.toURI.toString, "test.parquet") From 60a7d4b35c18f65e7c52ae7bc671408adc9f92e9 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 6 Apr 2026 16:50:32 +0530 Subject: [PATCH 08/11] test: add map_sort tests for boolean, decimal, date, and timestamp keys --- .../comet/CometMapExpressionSuite.scala | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index fc148f181e..45b522a89e 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -408,4 +408,93 @@ class CometMapExpressionSuite extends CometTestBase { } } + test("map_sort with boolean keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select(map(lit(true), lit("yes"), lit(false), lit("no")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with decimal keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select( + map( + lit(BigDecimal("3.14")), + lit("pi"), + lit(BigDecimal("1.41")), + lit("sqrt2"), + lit(BigDecimal("2.72")), + lit("e")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with date keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select( + map( + lit(java.sql.Date.valueOf("2024-03-15")), + lit("march"), + lit(java.sql.Date.valueOf("2024-01-10")), + lit("jan"), + lit(java.sql.Date.valueOf("2024-02-20")), + lit("feb")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + + test("map_sort with timestamp keys") { + assume(isSpark40Plus, "map_sort was added in Spark 4.0") + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(5) + .select( + map( + lit(java.sql.Timestamp.valueOf("2024-03-15 10:30:00")), + lit("third"), + lit(java.sql.Timestamp.valueOf("2024-01-10 08:00:00")), + lit("first"), + lit(java.sql.Timestamp.valueOf("2024-02-20 14:15:00")), + lit("second")).alias("m")) + df.write.parquet(path.toString) + } + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql("SELECT map_sort(m) FROM t1")) + } + } + } + } From ab68ca533ab7667d390fcf604ca80ca6acd8876b Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 6 Apr 2026 16:50:37 +0530 Subject: [PATCH 09/11] test: add Rust unit tests for map_sort native implementation --- native/spark-expr/src/map_funcs/map_sort.rs | 247 +++++++++++++++++++- 1 file changed, 246 insertions(+), 1 deletion(-) diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs index 115bb0f6f1..3fc8c8b568 100644 --- a/native/spark-expr/src/map_funcs/map_sort.rs +++ b/native/spark-expr/src/map_funcs/map_sort.rs @@ -22,7 +22,8 @@ use datafusion::common::{exec_err, DataFusionError, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; -/// Sorts map entries by key in ascending order, matching Spark's MapSort semantics. +/// Sorts map entries by key in ascending order, matching Spark's `MapSort` semantics. +/// Takes a single map argument and returns a new map with entries sorted by key. pub fn spark_map_sort(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return exec_err!("map_sort function takes exactly one argument"); @@ -136,3 +137,247 @@ fn sort_map_array(array: &ArrayRef) -> Result { ordered, ))) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray, StructArray}; + use arrow::array::NullBufferBuilder; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::{Field, Fields}; + use datafusion::logical_expr::ColumnarValue; + + fn make_map_array( + key_values: &[i32], + str_values: &[&str], + offsets: &[i32], + nulls: Option>, + ) -> ArrayRef { + let keys = Arc::new(Int32Array::from(key_values.to_vec())) as ArrayRef; + let values = Arc::new(StringArray::from( + str_values.iter().map(|s| Some(*s)).collect::>(), + )) as ArrayRef; + + let key_field = Arc::new(Field::new("key", DataType::Int32, false)); + let value_field = Arc::new(Field::new("value", DataType::Utf8, true)); + + let entries = StructArray::new( + Fields::from(vec![Arc::clone(&key_field), Arc::clone(&value_field)]), + vec![keys, values], + None, + ); + + let entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(entries.fields().clone()), + false, + )); + + let null_buffer = nulls.map(|n| { + let mut builder = NullBufferBuilder::new(n.len()); + for v in n { + builder.append(v); + } + builder.finish().unwrap() + }); + + Arc::new(MapArray::new( + entries_field, + OffsetBuffer::new(offsets.to_vec().into()), + entries, + null_buffer, + false, + )) + } + + fn get_sorted_keys(result: &ArrayRef) -> Vec { + let map = result.as_any().downcast_ref::().unwrap(); + let entries = map.entries(); + let keys = entries.column(0).as_any().downcast_ref::().unwrap(); + keys.values().to_vec() + } + + fn get_sorted_values(result: &ArrayRef) -> Vec { + let map = result.as_any().downcast_ref::().unwrap(); + let entries = map.entries(); + let vals = entries.column(1).as_any().downcast_ref::().unwrap(); + (0..vals.len()).map(|i| vals.value(i).to_string()).collect() + } + + #[test] + fn test_sort_integer_keys() { + let array = make_map_array(&[3, 1, 2], &["c", "a", "b"], &[0, 3], None); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + assert_eq!(get_sorted_keys(arr), vec![1, 2, 3]); + assert_eq!(get_sorted_values(arr), vec!["a", "b", "c"]); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_multiple_rows() { + // Row 0: {3->c, 1->a}, Row 1: {5->e, 4->d} + let array = make_map_array( + &[3, 1, 5, 4], + &["c", "a", "e", "d"], + &[0, 2, 4], + None, + ); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + assert_eq!(get_sorted_keys(arr), vec![1, 3, 4, 5]); + assert_eq!(get_sorted_values(arr), vec!["a", "c", "d", "e"]); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_with_empty_map() { + // Row 0: {2->b, 1->a}, Row 1: empty, Row 2: {3->c} + let array = make_map_array( + &[2, 1, 3], + &["b", "a", "c"], + &[0, 2, 2, 3], + None, + ); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + let map = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(map.len(), 3); + assert_eq!(get_sorted_keys(arr), vec![1, 2, 3]); + // Verify offsets: row 0 has 2 entries, row 1 has 0, row 2 has 1 + let offsets = map.offsets(); + assert_eq!(offsets.as_ref(), &[0, 2, 2, 3]); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_with_null_map() { + // Row 0: {2->b, 1->a}, Row 1: null + let array = make_map_array( + &[2, 1], + &["b", "a"], + &[0, 2, 2], + Some(vec![true, false]), + ); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + let map = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(map.len(), 2); + assert!(!map.is_null(0)); + assert!(map.is_null(1)); + assert_eq!(get_sorted_keys(arr), vec![1, 2]); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_single_entry_map() { + let array = make_map_array(&[42], &["answer"], &[0, 1], None); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + assert_eq!(get_sorted_keys(arr), vec![42]); + assert_eq!(get_sorted_values(arr), vec!["answer"]); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_all_empty_maps() { + // Two empty maps + let array = make_map_array(&[], &[], &[0, 0, 0], None); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + let map = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(map.len(), 2); + let offsets = map.offsets(); + assert_eq!(offsets.as_ref(), &[0, 0, 0]); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_preserves_field_names() { + // Use custom field names to verify preservation + let keys = Arc::new(Int32Array::from(vec![2, 1])) as ArrayRef; + let values = Arc::new(StringArray::from(vec!["b", "a"])) as ArrayRef; + let key_field = Arc::new(Field::new("my_key", DataType::Int32, false)); + let value_field = Arc::new(Field::new("my_value", DataType::Utf8, true)); + let entries = StructArray::new( + Fields::from(vec![Arc::clone(&key_field), Arc::clone(&value_field)]), + vec![keys, values], + None, + ); + let entries_field = Arc::new(Field::new( + "my_entries", + DataType::Struct(entries.fields().clone()), + false, + )); + let map = MapArray::new( + entries_field, + OffsetBuffer::new(vec![0, 2].into()), + entries, + None, + false, + ); + let array: ArrayRef = Arc::new(map); + let args = vec![ColumnarValue::Array(array)]; + let result = spark_map_sort(&args).unwrap(); + + match result { + ColumnarValue::Array(ref arr) => { + let map = arr.as_any().downcast_ref::().unwrap(); + let (kf, vf) = map.entries_fields(); + assert_eq!(kf.name(), "my_key"); + assert_eq!(vf.name(), "my_value"); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_sort_null_scalar() { + let args = vec![ColumnarValue::Scalar(ScalarValue::Null)]; + let result = spark_map_sort(&args).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Null) => {} + _ => panic!("Expected null scalar result"), + } + } + + #[test] + fn test_wrong_arg_count() { + let array = make_map_array(&[1], &["a"], &[0, 1], None); + let args = vec![ + ColumnarValue::Array(array.clone()), + ColumnarValue::Array(array), + ]; + assert!(spark_map_sort(&args).is_err()); + } +} From 0a7528dc7bdd3a8ce9e243a8f637e7c3563fa681 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 6 Apr 2026 16:52:33 +0530 Subject: [PATCH 10/11] style: apply cargo fmt and scalafmt formatting --- native/spark-expr/src/map_funcs/map_sort.rs | 38 ++++++++----------- .../comet/CometMapExpressionSuite.scala | 32 +++++++--------- 2 files changed, 29 insertions(+), 41 deletions(-) diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs index 3fc8c8b568..ebc2903c5b 100644 --- a/native/spark-expr/src/map_funcs/map_sort.rs +++ b/native/spark-expr/src/map_funcs/map_sort.rs @@ -113,8 +113,7 @@ fn sort_map_array(array: &ArrayRef) -> Result { arrow::array::new_empty_array(values.data_type()), ) } else { - let keys_refs: Vec<&dyn Array> = - sorted_keys_arrays.iter().map(|a| a.as_ref()).collect(); + let keys_refs: Vec<&dyn Array> = sorted_keys_arrays.iter().map(|a| a.as_ref()).collect(); let values_refs: Vec<&dyn Array> = sorted_values_arrays.iter().map(|a| a.as_ref()).collect(); ( @@ -141,8 +140,8 @@ fn sort_map_array(array: &ArrayRef) -> Result { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int32Array, StringArray, StructArray}; use arrow::array::NullBufferBuilder; + use arrow::array::{Int32Array, StringArray, StructArray}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{Field, Fields}; use datafusion::logical_expr::ColumnarValue; @@ -193,14 +192,22 @@ mod tests { fn get_sorted_keys(result: &ArrayRef) -> Vec { let map = result.as_any().downcast_ref::().unwrap(); let entries = map.entries(); - let keys = entries.column(0).as_any().downcast_ref::().unwrap(); + let keys = entries + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); keys.values().to_vec() } fn get_sorted_values(result: &ArrayRef) -> Vec { let map = result.as_any().downcast_ref::().unwrap(); let entries = map.entries(); - let vals = entries.column(1).as_any().downcast_ref::().unwrap(); + let vals = entries + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); (0..vals.len()).map(|i| vals.value(i).to_string()).collect() } @@ -222,12 +229,7 @@ mod tests { #[test] fn test_sort_multiple_rows() { // Row 0: {3->c, 1->a}, Row 1: {5->e, 4->d} - let array = make_map_array( - &[3, 1, 5, 4], - &["c", "a", "e", "d"], - &[0, 2, 4], - None, - ); + let array = make_map_array(&[3, 1, 5, 4], &["c", "a", "e", "d"], &[0, 2, 4], None); let args = vec![ColumnarValue::Array(array)]; let result = spark_map_sort(&args).unwrap(); @@ -243,12 +245,7 @@ mod tests { #[test] fn test_sort_with_empty_map() { // Row 0: {2->b, 1->a}, Row 1: empty, Row 2: {3->c} - let array = make_map_array( - &[2, 1, 3], - &["b", "a", "c"], - &[0, 2, 2, 3], - None, - ); + let array = make_map_array(&[2, 1, 3], &["b", "a", "c"], &[0, 2, 2, 3], None); let args = vec![ColumnarValue::Array(array)]; let result = spark_map_sort(&args).unwrap(); @@ -268,12 +265,7 @@ mod tests { #[test] fn test_sort_with_null_map() { // Row 0: {2->b, 1->a}, Row 1: null - let array = make_map_array( - &[2, 1], - &["b", "a"], - &[0, 2, 2], - Some(vec![true, false]), - ); + let array = make_map_array(&[2, 1], &["b", "a"], &[0, 2, 2], Some(vec![true, false])); let args = vec![ColumnarValue::Array(array)]; let result = spark_map_sort(&args).unwrap(); diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index 45b522a89e..e2a276d8c4 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -237,8 +237,6 @@ class CometMapExpressionSuite extends CometTestBase { } } - - test("map_sort with integer keys") { assume(isSpark40Plus, "map_sort was added in Spark 4.0") withTempDir { dir => @@ -457,14 +455,13 @@ class CometMapExpressionSuite extends CometTestBase { withSQLConf(CometConf.COMET_ENABLED.key -> "false") { val df = spark .range(5) - .select( - map( - lit(java.sql.Date.valueOf("2024-03-15")), - lit("march"), - lit(java.sql.Date.valueOf("2024-01-10")), - lit("jan"), - lit(java.sql.Date.valueOf("2024-02-20")), - lit("feb")).alias("m")) + .select(map( + lit(java.sql.Date.valueOf("2024-03-15")), + lit("march"), + lit(java.sql.Date.valueOf("2024-01-10")), + lit("jan"), + lit(java.sql.Date.valueOf("2024-02-20")), + lit("feb")).alias("m")) df.write.parquet(path.toString) } spark.read.parquet(path.toString).createOrReplaceTempView("t1") @@ -481,14 +478,13 @@ class CometMapExpressionSuite extends CometTestBase { withSQLConf(CometConf.COMET_ENABLED.key -> "false") { val df = spark .range(5) - .select( - map( - lit(java.sql.Timestamp.valueOf("2024-03-15 10:30:00")), - lit("third"), - lit(java.sql.Timestamp.valueOf("2024-01-10 08:00:00")), - lit("first"), - lit(java.sql.Timestamp.valueOf("2024-02-20 14:15:00")), - lit("second")).alias("m")) + .select(map( + lit(java.sql.Timestamp.valueOf("2024-03-15 10:30:00")), + lit("third"), + lit(java.sql.Timestamp.valueOf("2024-01-10 08:00:00")), + lit("first"), + lit(java.sql.Timestamp.valueOf("2024-02-20 14:15:00")), + lit("second")).alias("m")) df.write.parquet(path.toString) } spark.read.parquet(path.toString).createOrReplaceTempView("t1") From c40332fc6dc690da6b45a8c3ba6eacb0dad87b37 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Mon, 6 Apr 2026 20:30:37 +0530 Subject: [PATCH 11/11] fix: correct RowOrdering import path for Spark 4.0 --- .../main/spark-4.0/org/apache/comet/shims/CometExprShim.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 86e0a882b9..1ccc023ec4 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -20,7 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.{RowOrdering, StaticInvoke} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, MapType, StringType}