From ae6237f787be6c7d089d1be42153c951242f47d8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 24 Apr 2026 17:44:36 -0600 Subject: [PATCH 1/5] feat: add MapSort expression support for Spark 4.0 Add native map_sort scalar function that sorts map entries by key in ascending order, and wire it up via the Spark 4.0 CometExprShim so that MapSort expressions are accelerated instead of falling back to Spark. Re-enable all CometColumnarShuffleSuite map tests that were skipped for Spark 4.0. Closes #1941 Co-Authored-By: Claude Opus 4.6 --- native/spark-expr/src/comet_scalar_funcs.rs | 5 + native/spark-expr/src/lib.rs | 1 + native/spark-expr/src/map_funcs/map_sort.rs | 707 ++++++++++++++++++ native/spark-expr/src/map_funcs/mod.rs | 19 + .../apache/comet/shims/CometExprShim.scala | 5 + .../exec/CometColumnarShuffleSuite.scala | 65 +- 6 files changed, 742 insertions(+), 60 deletions(-) create mode 100644 native/spark-expr/src/map_funcs/map_sort.rs create mode 100644 native/spark-expr/src/map_funcs/mod.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index a3ea0e6816..0957868a60 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License. 
use crate::hash_funcs::*; +use crate::map_funcs::spark_map_sort; use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::log::spark_log; @@ -191,6 +192,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_get_json_object); make_comet_scalar_udf!("get_json_object", func, without data_type) } + "map_sort" => { + let func = Arc::new(spark_map_sort); + make_comet_scalar_udf!("spark_map_sort", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 963ce62e2b..42bfb80e86 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -57,6 +57,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain}; mod conditional_funcs; mod conversion_funcs; +mod map_funcs; mod math_funcs; mod nondetermenistic_funcs; diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs new file mode 100644 index 0000000000..a5253aeefa --- /dev/null +++ b/native/spark-expr/src/map_funcs/map_sort.rs @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, MapArray, StructArray}; +use arrow::compute::{concat, sort_to_indices, take, SortOptions}; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, DataFusionError}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Spark compatible `MapSort` implementation. +/// Sorts the entries of each map in a MapArray by key in ascending order, without changing the order +/// of the maps in the array. +pub fn spark_map_sort(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!("spark_map_sort expects exactly one argument"); + } + + let arr_arg: ArrayRef = match &args[0] { + ColumnarValue::Array(array) => Arc::::clone(array), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, + }; + + let (maps_arg, map_field, is_sorted) = match arr_arg.data_type() { + DataType::Map(map_field, is_sorted) => { + let maps_arg = arr_arg.as_any().downcast_ref::().unwrap(); + (maps_arg, map_field, is_sorted) + } + _ => return exec_err!("spark_map_sort expects Map type as argument"), + }; + + let maps_arg_entries = maps_arg.entries(); + let maps_arg_offsets = maps_arg.offsets(); + + let mut sorted_map_entries_vec: Vec = Vec::with_capacity(maps_arg.len()); + + for idx in 0..maps_arg.len() { + let map_start = maps_arg_offsets[idx] as usize; + let map_end = maps_arg_offsets[idx + 1] as usize; + let map_len = map_end - map_start; + + let map_entries = maps_arg_entries.slice(map_start, map_len); + + if map_len == 0 { + sorted_map_entries_vec.push(Arc::new(map_entries)); + continue; + } 
+ + let map_keys = map_entries.column(0); + let sort_options = SortOptions { + descending: false, + nulls_first: true, + }; + let sorted_indices = sort_to_indices(&map_keys, Some(sort_options), None)?; + + let sorted_map_entries = take(&map_entries, &sorted_indices, None)?; + sorted_map_entries_vec.push(sorted_map_entries); + } + + let sorted_map_entries_arr: Vec<&dyn Array> = sorted_map_entries_vec + .iter() + .map(|arr| arr.as_ref()) + .collect(); + let combined_sorted_map_entries = concat(&sorted_map_entries_arr)?; + let sorted_map_struct = combined_sorted_map_entries + .as_any() + .downcast_ref::() + .unwrap(); + + // Preserve the original is_sorted flag to keep schema consistent + let sorted_map_arr = Arc::new(MapArray::try_new( + Arc::::clone(map_field), + maps_arg.offsets().clone(), + sorted_map_struct.clone(), + maps_arg.nulls().cloned(), + *is_sorted, + )?); + + Ok(ColumnarValue::Array(sorted_map_arr)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow::array::{Int32Array, ListArray, ListBuilder, MapFieldNames, StringArray}; + use datafusion::common::ScalarValue; + use std::sync::Arc; + + macro_rules! 
build_map { + ( + $key_builder:expr, + $value_builder:expr, + $keys:expr, + $values:expr, + $validity:expr, + $entries_builder_fn:ident + ) => {{ + let mut map_builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + $key_builder, + $value_builder, + ); + + assert_eq!($keys.len(), $values.len()); + assert_eq!($keys.len(), $validity.len()); + + let total_maps = $keys.len(); + for map_idx in 0..total_maps { + let map_keys = &$keys[map_idx]; + let map_values = &$values[map_idx]; + assert_eq!(map_keys.len(), map_values.len()); + + let map_entries = map_keys.len(); + for entry_idx in 0..map_entries { + let map_key = &map_keys[entry_idx]; + let map_value = &map_values[entry_idx]; + $entries_builder_fn!(map_builder, map_key, map_value); + } + + let is_valid = $validity[map_idx]; + map_builder.append(is_valid).unwrap(); + } + + map_builder.finish() + }}; + } + + macro_rules! default_map_entries_builder { + ($map_builder:expr, $key:expr, $value:expr) => {{ + $map_builder.keys().append_value($key.clone()); + $map_builder.values().append_value($value.clone().unwrap()); + }}; + } + + macro_rules! nested_map_entries_builder { + ($map_builder:expr, $key:expr, $value:expr) => {{ + $map_builder.keys().append_value($key.clone()); + + let inner_map_builder = $map_builder.values(); + + let (inner_keys, inner_values, inner_valid) = $value; + assert_eq!(inner_keys.len(), inner_values.len()); + + let inner_entries = inner_keys.len(); + for inner_idx in 0..inner_entries { + let inner_key_val = &inner_keys[inner_idx]; + let inner_value = &inner_values[inner_idx]; + default_map_entries_builder!(inner_map_builder, inner_key_val, inner_value); + } + + inner_map_builder.append(*inner_valid).unwrap(); + }}; + } + + macro_rules! 
verify_result { + ( + $key_type:ty, + $value_type:ty, + $result:expr, + $expected_map_arr:expr, + $verify_entries_fn:ident + ) => {{ + match $result { + ColumnarValue::Array(actual_arr) => { + let actual_map_arr = actual_arr.as_any().downcast_ref::().unwrap(); + + assert_eq!(actual_map_arr.len(), $expected_map_arr.len()); + assert_eq!(actual_map_arr.offsets(), $expected_map_arr.offsets()); + assert_eq!(actual_map_arr.nulls(), $expected_map_arr.nulls()); + assert_eq!(actual_map_arr.data_type(), $expected_map_arr.data_type()); + + let actual_entries = actual_map_arr.entries(); + let actual_keys = actual_entries + .column(0) + .as_any() + .downcast_ref::<$key_type>() + .unwrap(); + let actual_values = actual_entries + .column(1) + .as_any() + .downcast_ref::<$value_type>() + .unwrap(); + + let expected_entries = $expected_map_arr.entries(); + let expected_keys = expected_entries + .column(0) + .as_any() + .downcast_ref::<$key_type>() + .unwrap(); + let expected_values = expected_entries + .column(1) + .as_any() + .downcast_ref::<$value_type>() + .unwrap(); + + assert_eq!(actual_keys.len(), expected_keys.len()); + assert_eq!(actual_values.len(), expected_values.len()); + + $verify_entries_fn!( + expected_entries.len(), + actual_keys, + expected_keys, + actual_values, + expected_values + ); + } + unexpected_arr => { + panic!("Actual result: {unexpected_arr:?} is not an Array ColumnarValue") + } + } + }}; + } + + macro_rules! default_entries_verifier { + ( + $entries_len:expr, + $actual_keys:expr, + $expected_keys:expr, + $actual_values:expr, + $expected_values:expr + ) => {{ + for idx in 0..$entries_len { + assert_eq!($actual_keys.value(idx), $expected_keys.value(idx)); + assert_eq!($actual_values.value(idx), $expected_values.value(idx)); + } + }}; + } + + macro_rules! 
list_entries_verifier { + ( + $entries_len:expr, + $actual_keys:expr, + $expected_keys:expr, + $actual_values:expr, + $expected_values:expr + ) => {{ + for idx in 0..$entries_len { + let actual_list = $actual_keys.value(idx); + let expected_list = $expected_keys.value(idx); + assert!(actual_list.eq(&expected_list)); + assert_eq!($actual_values.value(idx), $expected_values.value(idx)); + } + }}; + } + + #[test] + fn test_map_sort_with_string_keys() { + let keys_arg: [Vec; 4] = [ + vec!["c".into(), "a".into(), "b".into()], + vec!["z".into(), "y".into(), "x".into()], + vec!["a".into(), "b".into(), "c".into()], + vec!["fusion".into(), "comet".into(), "data".into()], + ]; + let values_arg = [ + vec![Some(3), Some(1), Some(2)], + vec![Some(30), Some(20), Some(10)], + vec![Some(1), Some(2), Some(3)], + vec![Some(300), Some(100), Some(200)], + ]; + let validity = [true, true, true, true]; + + let map_arr_arg = build_map!( + StringBuilder::new(), + Int32Builder::new(), + keys_arg, + values_arg, + validity, + default_map_entries_builder + ); + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_keys: [Vec; 4] = [ + vec!["a".into(), "b".into(), "c".into()], + vec!["x".into(), "y".into(), "z".into()], + vec!["a".into(), "b".into(), "c".into()], + vec!["comet".into(), "data".into(), "fusion".into()], + ]; + let expected_values = [ + vec![Some(1), Some(2), Some(3)], + vec![Some(10), Some(20), Some(30)], + vec![Some(1), Some(2), Some(3)], + vec![Some(100), Some(200), Some(300)], + ]; + let expected_validity = [true, true, true, true]; + + let expected_map_arr = build_map!( + StringBuilder::new(), + Int32Builder::new(), + expected_keys, + expected_values, + expected_validity, + default_map_entries_builder + ); + verify_result!( + StringArray, + Int32Array, + result, + expected_map_arr, + default_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_int_keys() { + let keys_arg = [ + vec![3, 2, 1], + 
vec![100, 50, 20], + vec![20, 50, 100], + vec![-5, 0, -1], + ]; + let values_arg: [Vec>; 4] = [ + vec![Some("three".into()), Some("two".into()), Some("one".into())], + vec![ + Some("hundred".into()), + Some("fifty".into()), + Some("twenty".into()), + ], + vec![ + Some("twenty".into()), + Some("fifty".into()), + Some("hundred".into()), + ], + vec![ + Some("minus five".into()), + Some("zero".into()), + Some("minus one".into()), + ], + ]; + let validity = [true, true, true, true]; + + let map_arr_arg = build_map!( + Int32Builder::new(), + StringBuilder::new(), + keys_arg, + values_arg, + validity, + default_map_entries_builder + ); + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_keys = [ + vec![1, 2, 3], + vec![20, 50, 100], + vec![20, 50, 100], + vec![-5, -1, 0], + ]; + let expected_values: [Vec>; 4] = [ + vec![Some("one".into()), Some("two".into()), Some("three".into())], + vec![ + Some("twenty".into()), + Some("fifty".into()), + Some("hundred".into()), + ], + vec![ + Some("twenty".into()), + Some("fifty".into()), + Some("hundred".into()), + ], + vec![ + Some("minus five".into()), + Some("minus one".into()), + Some("zero".into()), + ], + ]; + let expected_validity = [true, true, true, true]; + + let expected_map_arr = build_map!( + Int32Builder::new(), + StringBuilder::new(), + expected_keys, + expected_values, + expected_validity, + default_map_entries_builder + ); + verify_result!( + Int32Array, + StringArray, + result, + expected_map_arr, + default_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_nested_maps() { + let outer_keys: [String; 2] = ["outer_k2".into(), "outer_k1".into()]; + let inner_keys: [[String; 2]; 2] = [ + ["outer_k2->inner_k1".into(), "outer_k2->inner_k2".into()], + ["outer_k1->inner_k1".into(), "outer_k1->inner_k2".into()], + ]; + let inner_values: [[Option; 2]; 2] = [ + [ + Some("outer_k2->inner_k1->inner_v1".into()), + 
Some("outer_k2->inner_k2->inner_v2".into()), + ], + [ + Some("outer_k1->inner_k1->inner_v1".into()), + Some("outer_k1->inner_k2->inner_v2".into()), + ], + ]; + let outer_values = [ + (&inner_keys[0], &inner_values[0], true), + (&inner_keys[1], &inner_values[1], true), + ]; + + let keys_arg = [outer_keys]; + let values_arg = [outer_values]; + let validity = [true]; + + let map_arr_arg = build_map!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + keys_arg, + values_arg, + validity, + nested_map_entries_builder + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_outer_keys: [String; 2] = ["outer_k1".into(), "outer_k2".into()]; + let expected_inner_keys: [[String; 2]; 2] = [ + ["outer_k1->inner_k1".into(), "outer_k1->inner_k2".into()], + ["outer_k2->inner_k1".into(), "outer_k2->inner_k2".into()], + ]; + let expected_inner_values: [[Option; 2]; 2] = [ + [ + Some("outer_k1->inner_k1->inner_v1".into()), + Some("outer_k1->inner_k2->inner_v2".into()), + ], + [ + Some("outer_k2->inner_k1->inner_v1".into()), + Some("outer_k2->inner_k2->inner_v2".into()), + ], + ]; + let expected_outer_values = [ + (&expected_inner_keys[0], &expected_inner_values[0], true), + (&expected_inner_keys[1], &expected_inner_values[1], true), + ]; + + let expected_keys_arg = [expected_outer_keys]; + let expected_values_arg = [expected_outer_values]; + let expected_validity = [true]; + + let expected_map_arr = build_map!( + StringBuilder::new(), + MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + StringBuilder::new(), + ), + expected_keys_arg, + expected_values_arg, + expected_validity, + nested_map_entries_builder + ); + + verify_result!( + StringArray, + MapArray, + result, + 
expected_map_arr, + default_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_list_int_keys() { + let keys_arg = [vec![ + vec![Some(3), Some(2)], + vec![Some(1), Some(2)], + vec![Some(2), Some(1)], + ]]; + let values_arg: [Vec>; 1] = [vec![ + Some("three_two".into()), + Some("one_two".into()), + Some("two_one".into()), + ]]; + let validity = [true]; + + let map_arr_arg = build_map!( + ListBuilder::new(Int32Builder::new()), + StringBuilder::new(), + keys_arg, + values_arg, + validity, + default_map_entries_builder + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_keys = [vec![ + vec![Some(1), Some(2)], + vec![Some(2), Some(1)], + vec![Some(3), Some(2)], + ]]; + let expected_values: [Vec>; 1] = [vec![ + Some("one_two".into()), + Some("two_one".into()), + Some("three_two".into()), + ]]; + let expected_validity = [true]; + + let expected_map_arr = build_map!( + ListBuilder::new(Int32Builder::new()), + StringBuilder::new(), + expected_keys, + expected_values, + expected_validity, + default_map_entries_builder + ); + + verify_result!( + ListArray, + StringArray, + result, + expected_map_arr, + list_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_list_string_keys() { + let keys_arg: [Vec>>; 1] = [vec![ + vec![Some("c".into()), Some("b".into())], + vec![Some("a".into()), Some("b".into())], + vec![Some("b".into()), Some("a".into())], + ]]; + let values_arg: [Vec>; 1] = [vec![Some(32), Some(12), Some(21)]]; + let validity = [true]; + + let map_arr_arg = build_map!( + ListBuilder::new(StringBuilder::new()), + Int32Builder::new(), + keys_arg, + values_arg, + validity, + default_map_entries_builder + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + + let expected_keys: [Vec>>; 1] = [vec![ + vec![Some("a".into()), Some("b".into())], + vec![Some("b".into()), Some("a".into())], + vec![Some("c".into()), 
Some("b".into())], + ]]; + let expected_values: [Vec>; 1] = [vec![Some(12), Some(21), Some(32)]]; + let expected_validity = [true]; + + let expected_map_arr = build_map!( + ListBuilder::new(StringBuilder::new()), + Int32Builder::new(), + expected_keys, + expected_values, + expected_validity, + default_map_entries_builder + ); + + verify_result!( + ListArray, + Int32Array, + result, + expected_map_arr, + list_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_scalar_argument() { + let map_array = build_map!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec!["b".to_string(), "a".to_string()]], + vec![vec![Some(2), Some(1)]], + vec![true], + default_map_entries_builder + ); + + let args = vec![ColumnarValue::Scalar( + ScalarValue::try_from_array(&map_array, 0).unwrap(), + )]; + let result = spark_map_sort(&args).unwrap(); + + let expected_map_arr = build_map!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec!["a".to_string(), "b".to_string()]], + vec![vec![Some(1), Some(2)]], + vec![true], + default_map_entries_builder + ); + verify_result!( + StringArray, + Int32Array, + result, + expected_map_arr, + default_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_empty_map() { + let map_arr_arg = build_map!( + StringBuilder::new(), + Int32Builder::new(), + vec![Vec::::new()], + vec![Vec::>::new()], + vec![false], + default_map_entries_builder + ); + + let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))]; + let result = spark_map_sort(&args).unwrap(); + let expected_map_arr = build_map!( + StringBuilder::new(), + Int32Builder::new(), + vec![Vec::::new()], + vec![Vec::>::new()], + vec![false], + default_map_entries_builder + ); + verify_result!( + StringArray, + Int32Array, + result, + expected_map_arr, + default_entries_verifier + ); + } + + #[test] + fn test_map_sort_with_invalid_arguments() { + let result = spark_map_sort(&[]); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + 
.contains("spark_map_sort expects exactly one argument")); + + let map_array = build_map!( + StringBuilder::new(), + Int32Builder::new(), + vec![vec!["a".to_string()]], + vec![vec![Some(1)]], + vec![true], + default_map_entries_builder + ); + + let args = vec![ + ColumnarValue::Array(Arc::new(map_array.clone())), + ColumnarValue::Array(Arc::new(map_array)), + ]; + let result = spark_map_sort(&args); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("spark_map_sort expects exactly one argument")); + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let args = vec![ColumnarValue::Array(int_array)]; + + let result = spark_map_sort(&args); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("spark_map_sort expects Map type as argument")); + } +} diff --git a/native/spark-expr/src/map_funcs/mod.rs b/native/spark-expr/src/map_funcs/mod.rs new file mode 100644 index 0000000000..7288b847a8 --- /dev/null +++ b/native/spark-expr/src/map_funcs/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod map_sort; +pub use map_sort::spark_map_sort; diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 06e9d2278a..ca66635c03 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -128,6 +128,11 @@ trait CometExprShim extends CommonStringExprs { case _ => None } + case ms: MapSort => + val childExpr = exprToProtoInternal(ms.child, inputs, binding) + val mapSortExpr = scalarFunctionExprToProto("map_sort", childExpr) + optExprWithInfo(mapSortExpr, ms, ms.child) + case _ => None } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 7d865829b6..a0c63aac92 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf -import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { protected val adaptiveExecutionEnabled: Boolean @@ -164,13 +163,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } withParquetTable((0 until 50).map(i => (Map(i -> Seq(i, i + 1)), i + 1)), "tbl") { @@ -179,13 +172,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar 
.repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } withParquetTable((0 until 50).map(i => (Map((i, i.toString) -> i), i + 1)), "tbl") { @@ -194,13 +181,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } withParquetTable((0 until 50).map(i => (Map(i -> (i, i.toString)), i + 1)), "tbl") { @@ -209,13 +190,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } } } @@ -238,13 +213,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } } } @@ -344,78 +313,54 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar } test("columnar shuffle on map [bool]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(true, false)) } 
test("columnar shuffle on map [byte]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toByte, 1.toByte)) } test("columnar shuffle on map [short]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toShort, 1.toShort)) } test("columnar shuffle on map [int]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0, 1)) } test("columnar shuffle on map [long]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toLong, 1.toLong)) } test("columnar shuffle on map [float]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toFloat, 1.toFloat)) } test("columnar shuffle on map [double]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toDouble, 1.toDouble)) } test("columnar shuffle on map [date]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(new java.sql.Date(0.toLong), new java.sql.Date(1.toLong))) } test("columnar shuffle on map [timestamp]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest( 50, Seq(new java.sql.Timestamp(0.toLong), new java.sql.Timestamp(1.toLong))) } test("columnar shuffle on map [decimal]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest( 50, Seq(new java.math.BigDecimal(0.toLong), new java.math.BigDecimal(1.toLong))) } test("columnar shuffle on map [string]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toString, 1.toString)) } test("columnar shuffle on 
map [binary]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toString.getBytes(), 1.toString.getBytes())) } From f99ba23bc74264bbc0f3ed811cab5f39790b08ed Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 24 Apr 2026 17:59:23 -0600 Subject: [PATCH 2/5] test: fall back to Spark for map array element shuffle on Spark 4.0 Spark 4.0 normalizes shuffle keys containing arrays of maps via transform(arr, x -> mapsort(x)), which Comet does not yet support because ArrayTransform with a lambda body has no serde. Mark the columnar shuffle on map array element test as expecting the fallback on Spark 4.0+ while still verifying answer correctness. --- .../org/apache/comet/exec/CometColumnarShuffleSuite.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index a0c63aac92..26e9f81931 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { protected val adaptiveExecutionEnabled: Boolean @@ -213,7 +214,11 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkShuffleAnswer(df, 1) + // Spark 4.0 normalizes shuffle keys containing arrays of maps via + // transform(arr, x -> mapsort(x)), which Comet doesn't yet + // support, so the shuffle falls back to Spark. 
+ val expectedShuffles = if (isSpark40Plus) 0 else 1 + checkShuffleAnswer(df, expectedShuffles) } } } From 0480e82d72cad13236ca6699a829395fdf16d9d3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 24 Apr 2026 20:26:17 -0600 Subject: [PATCH 3/5] fix: pass return type for MapSort serde The MapSort serde for Spark 4.0 called scalarFunctionExprToProto without a return type. The Rust planner then looked up "map_sort" in the session UDF registry to infer the type, but map_sort is only handled via the create_comet_physical_fun match dispatch, not registered as a UDF, causing "There is no UDF named 'map_sort' in the registry" at execution time (e.g., group-by on a map column in CollationSuite). Pass ms.dataType explicitly via scalarFunctionExprToProtoWithReturnType, matching the pattern used by ceil, floor, and other scalar functions. --- .../org/apache/comet/shims/CometExprShim.scala | 8 ++++++-- .../org/apache/comet/CometMapExpressionSuite.scala | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index ca66635c03..b5b064615d 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -30,7 +30,7 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} /** * `CometExprShim` acts as a shim for parsing expressions from 
different Spark versions. @@ -130,7 +130,11 @@ trait CometExprShim extends CommonStringExprs { case ms: MapSort => val childExpr = exprToProtoInternal(ms.child, inputs, binding) - val mapSortExpr = scalarFunctionExprToProto("map_sort", childExpr) + val mapSortExpr = scalarFunctionExprToProtoWithReturnType( + "map_sort", + ms.dataType, + failOnError = false, + childExpr) optExprWithInfo(mapSortExpr, ms, ms.child) case _ => None diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index 03db26e566..f3c7d9f23e 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.BinaryType +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus import org.apache.comet.serde.CometMapFromEntries import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions} @@ -221,6 +222,18 @@ class CometMapExpressionSuite extends CometTestBase { } } + test("group by map column with string values") { + assume(isSpark40Plus, "Spark 4.0 inserts MapSort for group-by on map keys") + withTable("t_map_group") { + sql(""" + |CREATE TABLE t_map_group USING parquet AS + |SELECT map(cast(id as string), cast(id + 100 as string)) as m + |FROM range(5) + """.stripMargin) + checkSparkAnswer(sql("SELECT m, count(*) FROM t_map_group GROUP BY m")) + } + } + test("map_from_entries - fallback for binary type") { def fallbackReason(reason: String) = reason val table = "t2" From 9ad865ab45abdbea736235cefc9c0254174e3e78 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 25 Apr 2026 07:27:47 -0600 Subject: [PATCH 4/5] fix: fall back to Spark for MapSort with unsupported key types Arrow's sort_to_indices does not support Struct (and other complex) key types, so 
map_sort fails at runtime when the map key is a struct. Check key type via supportedScalarSortElementType and fall back to Spark when the key type is not natively sortable. This fixes 4 CollationSuite failures in spark-sql-auto-sql_core-1 for Spark 4.0: 'Group by on map containing structs with ...'. --- .../apache/comet/shims/CometExprShim.scala | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index b5b064615d..ce5df088fe 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, MapType, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. 
@@ -129,13 +129,19 @@ trait CometExprShim extends CommonStringExprs { } case ms: MapSort => - val childExpr = exprToProtoInternal(ms.child, inputs, binding) - val mapSortExpr = scalarFunctionExprToProtoWithReturnType( - "map_sort", - ms.dataType, - failOnError = false, - childExpr) - optExprWithInfo(mapSortExpr, ms, ms.child) + val keyType = ms.dataType.asInstanceOf[MapType].keyType + if (!supportedScalarSortElementType(keyType)) { + withInfo(ms, s"MapSort on map with key type $keyType is not supported") + None + } else { + val childExpr = exprToProtoInternal(ms.child, inputs, binding) + val mapSortExpr = scalarFunctionExprToProtoWithReturnType( + "map_sort", + ms.dataType, + failOnError = false, + childExpr) + optExprWithInfo(mapSortExpr, ms, ms.child) + } case _ => None } From bb0678fc0fbe63278bef0cefd6c93c1e3fb9d110 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 25 Apr 2026 08:44:52 -0600 Subject: [PATCH 5/5] test: expect Spark fallback for shuffle on map with array/struct keys Spark 4.0 wraps map shuffle keys in mapsort(...). Comet's map_sort relies on Arrow's sort_to_indices, which only supports scalar key types, so maps with array or struct keys fall back to Spark. Update the 'columnar shuffle on array/struct map key/value' test to expect 0 Comet shuffles for the array-key and struct-key cases on Spark 4.0+, while keeping the scalar-key cases at 1. 
--- .../org/apache/comet/exec/CometColumnarShuffleSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 26e9f81931..fe2866da66 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -152,6 +152,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar } test("columnar shuffle on array/struct map key/value") { + // Spark 4.0 normalizes maps used as shuffle keys with mapsort(...). Comet's map_sort + // relies on Arrow's sort_to_indices, which only supports scalar key types, so a map + // with array or struct keys cannot be sorted natively and the shuffle falls back. + val complexKeyShuffles = if (isSpark40Plus) 0 else 1 Seq("false", "true").foreach { execEnabled => Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => @@ -164,7 +168,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkShuffleAnswer(df, 1) + checkShuffleAnswer(df, complexKeyShuffles) } withParquetTable((0 until 50).map(i => (Map(i -> Seq(i, i + 1)), i + 1)), "tbl") { @@ -182,7 +186,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkShuffleAnswer(df, 1) + checkShuffleAnswer(df, complexKeyShuffles) } withParquetTable((0 until 50).map(i => (Map(i -> (i, i.toString)), i + 1)), "tbl") {