diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index a3ea0e6816..0957868a60 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::hash_funcs::*;
+use crate::map_funcs::spark_map_sort;
 use crate::math_funcs::abs::abs;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
 use crate::math_funcs::log::spark_log;
@@ -191,6 +192,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(crate::string_funcs::spark_get_json_object);
             make_comet_scalar_udf!("get_json_object", func, without data_type)
         }
+        "map_sort" => {
+            let func = Arc::new(spark_map_sort);
+            make_comet_scalar_udf!("spark_map_sort", func, without data_type)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 963ce62e2b..42bfb80e86 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -57,6 +57,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain};
 
 mod conditional_funcs;
 mod conversion_funcs;
+mod map_funcs;
 mod math_funcs;
 mod nondetermenistic_funcs;
 
diff --git a/native/spark-expr/src/map_funcs/map_sort.rs b/native/spark-expr/src/map_funcs/map_sort.rs
new file mode 100644
index 0000000000..a5253aeefa
--- /dev/null
+++ b/native/spark-expr/src/map_funcs/map_sort.rs
@@ -0,0 +1,707 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrayRef, MapArray, StructArray};
+use arrow::compute::{concat, sort_to_indices, take, SortOptions};
+use arrow::datatypes::DataType;
+use datafusion::common::{exec_err, DataFusionError};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Spark-compatible `MapSort` implementation.
+/// Sorts the entries of each map in a `MapArray` by key in ascending order, without changing
+/// the order of the maps in the array.
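+///
+/// For example, a map with entries `{"c" -> 3, "a" -> 1, "b" -> 2}` is returned as
+/// `{"a" -> 1, "b" -> 2, "c" -> 3}`: entries are reordered by key and each value stays
+/// paired with its original key.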
+pub fn spark_map_sort(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.len() != 1 {
+        return exec_err!("spark_map_sort expects exactly one argument");
+    }
+
+    let arr_arg: ArrayRef = match &args[0] {
+        ColumnarValue::Array(array) => Arc::<dyn Array>::clone(array),
+        ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?,
+    };
+
+    let (maps_arg, map_field, is_sorted) = match arr_arg.data_type() {
+        DataType::Map(map_field, is_sorted) => {
+            let maps_arg = arr_arg.as_any().downcast_ref::<MapArray>().unwrap();
+            (maps_arg, map_field, is_sorted)
+        }
+        _ => return exec_err!("spark_map_sort expects Map type as argument"),
+    };
+
+    let maps_arg_entries = maps_arg.entries();
+    let maps_arg_offsets = maps_arg.offsets();
+
+    let mut sorted_map_entries_vec: Vec<ArrayRef> = Vec::with_capacity(maps_arg.len());
+
+    for idx in 0..maps_arg.len() {
+        let map_start = maps_arg_offsets[idx] as usize;
+        let map_end = maps_arg_offsets[idx + 1] as usize;
+        let map_len = map_end - map_start;
+
+        let map_entries = maps_arg_entries.slice(map_start, map_len);
+
+        if map_len == 0 {
+            sorted_map_entries_vec.push(Arc::new(map_entries));
+            continue;
+        }
+
+        let map_keys = map_entries.column(0);
+        let sort_options = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let sorted_indices = sort_to_indices(&map_keys, Some(sort_options), None)?;
+
+        let sorted_map_entries = take(&map_entries, &sorted_indices, None)?;
+        sorted_map_entries_vec.push(sorted_map_entries);
+    }
+
+    let sorted_map_entries_arr: Vec<&dyn Array> = sorted_map_entries_vec
+        .iter()
+        .map(|arr| arr.as_ref())
+        .collect();
+    let combined_sorted_map_entries = concat(&sorted_map_entries_arr)?;
+    let sorted_map_struct = combined_sorted_map_entries
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .unwrap();
+
+    // Preserve the original is_sorted flag to keep schema consistent
+    let sorted_map_arr = Arc::new(MapArray::try_new(
+        Arc::clone(map_field),
+        maps_arg.offsets().clone(),
+        sorted_map_struct.clone(),
+        maps_arg.nulls().cloned(),
+        *is_sorted,
+    )?);
+
+    Ok(ColumnarValue::Array(sorted_map_arr))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder};
+    use arrow::array::{Int32Array, ListArray, ListBuilder, MapFieldNames, StringArray};
+    use datafusion::common::ScalarValue;
+    use std::sync::Arc;
+
+    macro_rules! build_map {
+        (
+            $key_builder:expr,
+            $value_builder:expr,
+            $keys:expr,
+            $values:expr,
+            $validity:expr,
+            $entries_builder_fn:ident
+        ) => {{
+            let mut map_builder = MapBuilder::new(
+                Some(MapFieldNames {
+                    entry: "entries".into(),
+                    key: "key".into(),
+                    value: "value".into(),
+                }),
+                $key_builder,
+                $value_builder,
+            );
+
+            assert_eq!($keys.len(), $values.len());
+            assert_eq!($keys.len(), $validity.len());
+
+            let total_maps = $keys.len();
+            for map_idx in 0..total_maps {
+                let map_keys = &$keys[map_idx];
+                let map_values = &$values[map_idx];
+                assert_eq!(map_keys.len(), map_values.len());
+
+                let map_entries = map_keys.len();
+                for entry_idx in 0..map_entries {
+                    let map_key = &map_keys[entry_idx];
+                    let map_value = &map_values[entry_idx];
+                    $entries_builder_fn!(map_builder, map_key, map_value);
+                }
+
+                let is_valid = $validity[map_idx];
+                map_builder.append(is_valid).unwrap();
+            }
+
+            map_builder.finish()
+        }};
+    }
+
+    macro_rules! default_map_entries_builder {
+        ($map_builder:expr, $key:expr, $value:expr) => {{
+            $map_builder.keys().append_value($key.clone());
+            $map_builder.values().append_value($value.clone().unwrap());
+        }};
+    }
+
+    macro_rules! nested_map_entries_builder {
+        ($map_builder:expr, $key:expr, $value:expr) => {{
+            $map_builder.keys().append_value($key.clone());
+
+            let inner_map_builder = $map_builder.values();
+
+            let (inner_keys, inner_values, inner_valid) = $value;
+            assert_eq!(inner_keys.len(), inner_values.len());
+
+            let inner_entries = inner_keys.len();
+            for inner_idx in 0..inner_entries {
+                let inner_key_val = &inner_keys[inner_idx];
+                let inner_value = &inner_values[inner_idx];
+                default_map_entries_builder!(inner_map_builder, inner_key_val, inner_value);
+            }
+
+            inner_map_builder.append(*inner_valid).unwrap();
+        }};
+    }
+
+    macro_rules! verify_result {
+        (
+            $key_type:ty,
+            $value_type:ty,
+            $result:expr,
+            $expected_map_arr:expr,
+            $verify_entries_fn:ident
+        ) => {{
+            match $result {
+                ColumnarValue::Array(actual_arr) => {
+                    let actual_map_arr = actual_arr.as_any().downcast_ref::<MapArray>().unwrap();
+
+                    assert_eq!(actual_map_arr.len(), $expected_map_arr.len());
+                    assert_eq!(actual_map_arr.offsets(), $expected_map_arr.offsets());
+                    assert_eq!(actual_map_arr.nulls(), $expected_map_arr.nulls());
+                    assert_eq!(actual_map_arr.data_type(), $expected_map_arr.data_type());
+
+                    let actual_entries = actual_map_arr.entries();
+                    let actual_keys = actual_entries
+                        .column(0)
+                        .as_any()
+                        .downcast_ref::<$key_type>()
+                        .unwrap();
+                    let actual_values = actual_entries
+                        .column(1)
+                        .as_any()
+                        .downcast_ref::<$value_type>()
+                        .unwrap();
+
+                    let expected_entries = $expected_map_arr.entries();
+                    let expected_keys = expected_entries
+                        .column(0)
+                        .as_any()
+                        .downcast_ref::<$key_type>()
+                        .unwrap();
+                    let expected_values = expected_entries
+                        .column(1)
+                        .as_any()
+                        .downcast_ref::<$value_type>()
+                        .unwrap();
+
+                    assert_eq!(actual_keys.len(), expected_keys.len());
+                    assert_eq!(actual_values.len(), expected_values.len());
+
+                    $verify_entries_fn!(
+                        expected_entries.len(),
+                        actual_keys,
+                        expected_keys,
+                        actual_values,
+                        expected_values
+                    );
+                }
+                unexpected_arr => {
+                    panic!("Actual result: {unexpected_arr:?} is not an Array ColumnarValue")
+                }
+            }
+        }};
+    }
+
+    macro_rules! default_entries_verifier {
+        (
+            $entries_len:expr,
+            $actual_keys:expr,
+            $expected_keys:expr,
+            $actual_values:expr,
+            $expected_values:expr
+        ) => {{
+            for idx in 0..$entries_len {
+                assert_eq!($actual_keys.value(idx), $expected_keys.value(idx));
+                assert_eq!($actual_values.value(idx), $expected_values.value(idx));
+            }
+        }};
+    }
+
+    macro_rules! list_entries_verifier {
+        (
+            $entries_len:expr,
+            $actual_keys:expr,
+            $expected_keys:expr,
+            $actual_values:expr,
+            $expected_values:expr
+        ) => {{
+            for idx in 0..$entries_len {
+                let actual_list = $actual_keys.value(idx);
+                let expected_list = $expected_keys.value(idx);
+                assert!(actual_list.eq(&expected_list));
+                assert_eq!($actual_values.value(idx), $expected_values.value(idx));
+            }
+        }};
+    }
+
+    #[test]
+    fn test_map_sort_with_string_keys() {
+        let keys_arg: [Vec<String>; 4] = [
+            vec!["c".into(), "a".into(), "b".into()],
+            vec!["z".into(), "y".into(), "x".into()],
+            vec!["a".into(), "b".into(), "c".into()],
+            vec!["fusion".into(), "comet".into(), "data".into()],
+        ];
+        let values_arg = [
+            vec![Some(3), Some(1), Some(2)],
+            vec![Some(30), Some(20), Some(10)],
+            vec![Some(1), Some(2), Some(3)],
+            vec![Some(300), Some(100), Some(200)],
+        ];
+        let validity = [true, true, true, true];
+
+        let map_arr_arg = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            keys_arg,
+            values_arg,
+            validity,
+            default_map_entries_builder
+        );
+        let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))];
+        let result = spark_map_sort(&args).unwrap();
+
+        let expected_keys: [Vec<String>; 4] = [
+            vec!["a".into(), "b".into(), "c".into()],
+            vec!["x".into(), "y".into(), "z".into()],
+            vec!["a".into(), "b".into(), "c".into()],
+            vec!["comet".into(), "data".into(), "fusion".into()],
+        ];
+        let expected_values = [
+            vec![Some(1), Some(2), Some(3)],
+            vec![Some(10), Some(20), Some(30)],
+            vec![Some(1), Some(2), Some(3)],
+            vec![Some(100), Some(200), Some(300)],
+        ];
+        let expected_validity = [true, true, true, true];
+
+        let expected_map_arr = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            expected_keys,
+            expected_values,
+            expected_validity,
+            default_map_entries_builder
+        );
+        verify_result!(
+            StringArray,
+            Int32Array,
+            result,
+            expected_map_arr,
+            default_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_int_keys() {
+        let keys_arg = [
+            vec![3, 2, 1],
+            vec![100, 50, 20],
+            vec![20, 50, 100],
+            vec![-5, 0, -1],
+        ];
+        let values_arg: [Vec<Option<String>>; 4] = [
+            vec![Some("three".into()), Some("two".into()), Some("one".into())],
+            vec![
+                Some("hundred".into()),
+                Some("fifty".into()),
+                Some("twenty".into()),
+            ],
+            vec![
+                Some("twenty".into()),
+                Some("fifty".into()),
+                Some("hundred".into()),
+            ],
+            vec![
+                Some("minus five".into()),
+                Some("zero".into()),
+                Some("minus one".into()),
+            ],
+        ];
+        let validity = [true, true, true, true];
+
+        let map_arr_arg = build_map!(
+            Int32Builder::new(),
+            StringBuilder::new(),
+            keys_arg,
+            values_arg,
+            validity,
+            default_map_entries_builder
+        );
+        let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))];
+        let result = spark_map_sort(&args).unwrap();
+
+        let expected_keys = [
+            vec![1, 2, 3],
+            vec![20, 50, 100],
+            vec![20, 50, 100],
+            vec![-5, -1, 0],
+        ];
+        let expected_values: [Vec<Option<String>>; 4] = [
+            vec![Some("one".into()), Some("two".into()), Some("three".into())],
+            vec![
+                Some("twenty".into()),
+                Some("fifty".into()),
+                Some("hundred".into()),
+            ],
+            vec![
+                Some("twenty".into()),
+                Some("fifty".into()),
+                Some("hundred".into()),
+            ],
+            vec![
+                Some("minus five".into()),
+                Some("minus one".into()),
+                Some("zero".into()),
+            ],
+        ];
+        let expected_validity = [true, true, true, true];
+
+        let expected_map_arr = build_map!(
+            Int32Builder::new(),
+            StringBuilder::new(),
+            expected_keys,
+            expected_values,
+            expected_validity,
+            default_map_entries_builder
+        );
+        verify_result!(
+            Int32Array,
+            StringArray,
+            result,
+            expected_map_arr,
+            default_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_nested_maps() {
+        let outer_keys: [String; 2] = ["outer_k2".into(), "outer_k1".into()];
+        let inner_keys: [[String; 2]; 2] = [
+            ["outer_k2->inner_k1".into(), "outer_k2->inner_k2".into()],
+            ["outer_k1->inner_k1".into(), "outer_k1->inner_k2".into()],
+        ];
+        let inner_values: [[Option<String>; 2]; 2] = [
+            [
+                Some("outer_k2->inner_k1->inner_v1".into()),
+                Some("outer_k2->inner_k2->inner_v2".into()),
+            ],
+            [
+                Some("outer_k1->inner_k1->inner_v1".into()),
+                Some("outer_k1->inner_k2->inner_v2".into()),
+            ],
+        ];
+        let outer_values = [
+            (&inner_keys[0], &inner_values[0], true),
+            (&inner_keys[1], &inner_values[1], true),
+        ];
+
+        let keys_arg = [outer_keys];
+        let values_arg = [outer_values];
+        let validity = [true];
+
+        let map_arr_arg = build_map!(
+            StringBuilder::new(),
+            MapBuilder::new(
+                Some(MapFieldNames {
+                    entry: "entries".into(),
+                    key: "key".into(),
+                    value: "value".into(),
+                }),
+                StringBuilder::new(),
+                StringBuilder::new(),
+            ),
+            keys_arg,
+            values_arg,
+            validity,
+            nested_map_entries_builder
+        );
+
+        let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))];
+        let result = spark_map_sort(&args).unwrap();
+
+        let expected_outer_keys: [String; 2] = ["outer_k1".into(), "outer_k2".into()];
+        let expected_inner_keys: [[String; 2]; 2] = [
+            ["outer_k1->inner_k1".into(), "outer_k1->inner_k2".into()],
+            ["outer_k2->inner_k1".into(), "outer_k2->inner_k2".into()],
+        ];
+        let expected_inner_values: [[Option<String>; 2]; 2] = [
+            [
+                Some("outer_k1->inner_k1->inner_v1".into()),
+                Some("outer_k1->inner_k2->inner_v2".into()),
+            ],
+            [
+                Some("outer_k2->inner_k1->inner_v1".into()),
+                Some("outer_k2->inner_k2->inner_v2".into()),
+            ],
+        ];
+        let expected_outer_values = [
+            (&expected_inner_keys[0], &expected_inner_values[0], true),
+            (&expected_inner_keys[1], &expected_inner_values[1], true),
+        ];
+
+        let expected_keys_arg = [expected_outer_keys];
+        let expected_values_arg = [expected_outer_values];
+        let expected_validity = [true];
+
+        let expected_map_arr = build_map!(
+            StringBuilder::new(),
+            MapBuilder::new(
+                Some(MapFieldNames {
+                    entry: "entries".into(),
+                    key: "key".into(),
+                    value: "value".into(),
+                }),
+                StringBuilder::new(),
+                StringBuilder::new(),
+            ),
+            expected_keys_arg,
+            expected_values_arg,
+            expected_validity,
+            nested_map_entries_builder
+        );
+
+        verify_result!(
+            StringArray,
+            MapArray,
+            result,
+            expected_map_arr,
+            default_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_list_int_keys() {
+        let keys_arg = [vec![
+            vec![Some(3), Some(2)],
+            vec![Some(1), Some(2)],
+            vec![Some(2), Some(1)],
+        ]];
+        let values_arg: [Vec<Option<String>>; 1] = [vec![
+            Some("three_two".into()),
+            Some("one_two".into()),
+            Some("two_one".into()),
+        ]];
+        let validity = [true];
+
+        let map_arr_arg = build_map!(
+            ListBuilder::new(Int32Builder::new()),
+            StringBuilder::new(),
+            keys_arg,
+            values_arg,
+            validity,
+            default_map_entries_builder
+        );
+
+        let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))];
+        let result = spark_map_sort(&args).unwrap();
+
+        let expected_keys = [vec![
+            vec![Some(1), Some(2)],
+            vec![Some(2), Some(1)],
+            vec![Some(3), Some(2)],
+        ]];
+        let expected_values: [Vec<Option<String>>; 1] = [vec![
+            Some("one_two".into()),
+            Some("two_one".into()),
+            Some("three_two".into()),
+        ]];
+        let expected_validity = [true];
+
+        let expected_map_arr = build_map!(
+            ListBuilder::new(Int32Builder::new()),
+            StringBuilder::new(),
+            expected_keys,
+            expected_values,
+            expected_validity,
+            default_map_entries_builder
+        );
+
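+        // The keys in this test are lists, so the list-aware entry verifier is used below.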
+        verify_result!(
+            ListArray,
+            StringArray,
+            result,
+            expected_map_arr,
+            list_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_list_string_keys() {
+        let keys_arg: [Vec<Vec<Option<String>>>; 1] = [vec![
+            vec![Some("c".into()), Some("b".into())],
+            vec![Some("a".into()), Some("b".into())],
+            vec![Some("b".into()), Some("a".into())],
+        ]];
+        let values_arg: [Vec<Option<i32>>; 1] = [vec![Some(32), Some(12), Some(21)]];
+        let validity = [true];
+
+        let map_arr_arg = build_map!(
+            ListBuilder::new(StringBuilder::new()),
+            Int32Builder::new(),
+            keys_arg,
+            values_arg,
+            validity,
+            default_map_entries_builder
+        );
+
+        let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))];
+        let result = spark_map_sort(&args).unwrap();
+
+        let expected_keys: [Vec<Vec<Option<String>>>; 1] = [vec![
+            vec![Some("a".into()), Some("b".into())],
+            vec![Some("b".into()), Some("a".into())],
+            vec![Some("c".into()), Some("b".into())],
+        ]];
+        let expected_values: [Vec<Option<i32>>; 1] = [vec![Some(12), Some(21), Some(32)]];
+        let expected_validity = [true];
+
+        let expected_map_arr = build_map!(
+            ListBuilder::new(StringBuilder::new()),
+            Int32Builder::new(),
+            expected_keys,
+            expected_values,
+            expected_validity,
+            default_map_entries_builder
+        );
+
+        verify_result!(
+            ListArray,
+            Int32Array,
+            result,
+            expected_map_arr,
+            list_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_scalar_argument() {
+        let map_array = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            vec![vec!["b".to_string(), "a".to_string()]],
+            vec![vec![Some(2), Some(1)]],
+            vec![true],
+            default_map_entries_builder
+        );
+
+        let args = vec![ColumnarValue::Scalar(
+            ScalarValue::try_from_array(&map_array, 0).unwrap(),
+        )];
+        let result = spark_map_sort(&args).unwrap();
+
+        let expected_map_arr = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            vec![vec!["a".to_string(), "b".to_string()]],
+            vec![vec![Some(1), Some(2)]],
+            vec![true],
+            default_map_entries_builder
+        );
+        verify_result!(
+            StringArray,
+            Int32Array,
+            result,
+            expected_map_arr,
+            default_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_empty_map() {
+        let map_arr_arg = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            vec![Vec::<String>::new()],
+            vec![Vec::<Option<i32>>::new()],
+            vec![false],
+            default_map_entries_builder
+        );
+
+        let args = vec![ColumnarValue::Array(Arc::new(map_arr_arg))];
+        let result = spark_map_sort(&args).unwrap();
+        let expected_map_arr = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            vec![Vec::<String>::new()],
+            vec![Vec::<Option<i32>>::new()],
+            vec![false],
+            default_map_entries_builder
+        );
+        verify_result!(
+            StringArray,
+            Int32Array,
+            result,
+            expected_map_arr,
+            default_entries_verifier
+        );
+    }
+
+    #[test]
+    fn test_map_sort_with_invalid_arguments() {
+        let result = spark_map_sort(&[]);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("spark_map_sort expects exactly one argument"));
+
+        let map_array = build_map!(
+            StringBuilder::new(),
+            Int32Builder::new(),
+            vec![vec!["a".to_string()]],
+            vec![vec![Some(1)]],
+            vec![true],
+            default_map_entries_builder
+        );
+
+        let args = vec![
+            ColumnarValue::Array(Arc::new(map_array.clone())),
+            ColumnarValue::Array(Arc::new(map_array)),
+        ];
+        let result = spark_map_sort(&args);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("spark_map_sort expects exactly one argument"));
+
+        let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
+        let args = vec![ColumnarValue::Array(int_array)];
+
+        let result =
spark_map_sort(&args); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("spark_map_sort expects Map type as argument")); + } +} diff --git a/native/spark-expr/src/map_funcs/mod.rs b/native/spark-expr/src/map_funcs/mod.rs new file mode 100644 index 0000000000..7288b847a8 --- /dev/null +++ b/native/spark-expr/src/map_funcs/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod map_sort; +pub use map_sort::spark_map_sort; diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 38a3b9d726..ea538cd3c8 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, StringType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, MapType, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, supportedScalarSortElementType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. 
@@ -144,6 +144,21 @@ trait CometExprShim extends CommonStringExprs { case _ => None } + case ms: MapSort => + val keyType = ms.dataType.asInstanceOf[MapType].keyType + if (!supportedScalarSortElementType(keyType)) { + withInfo(ms, s"MapSort on map with key type $keyType is not supported") + None + } else { + val childExpr = exprToProtoInternal(ms.child, inputs, binding) + val mapSortExpr = scalarFunctionExprToProtoWithReturnType( + "map_sort", + ms.dataType, + failOnError = false, + childExpr) + optExprWithInfo(mapSortExpr, ms, ms.child) + } + case _ => None } } diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index 03db26e566..f3c7d9f23e 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.BinaryType +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus import org.apache.comet.serde.CometMapFromEntries import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions} @@ -221,6 +222,18 @@ class CometMapExpressionSuite extends CometTestBase { } } + test("group by map column with string values") { + assume(isSpark40Plus, "Spark 4.0 inserts MapSort for group-by on map keys") + withTable("t_map_group") { + sql(""" + |CREATE TABLE t_map_group USING parquet AS + |SELECT map(cast(id as string), cast(id + 100 as string)) as m + |FROM range(5) + """.stripMargin) + checkSparkAnswer(sql("SELECT m, count(*) FROM t_map_group GROUP BY m")) + } + } + test("map_from_entries - fallback for binary type") { def fallbackReason(reason: String) = reason val table = "t2" diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 7d865829b6..fe2866da66 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -152,6 +152,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar } test("columnar shuffle on array/struct map key/value") { + // Spark 4.0 normalizes maps used as shuffle keys with mapsort(...). Comet's map_sort + // relies on Arrow's sort_to_indices, which only supports scalar key types, so a map + // with array or struct keys cannot be sorted natively and the shuffle falls back. 
+ val complexKeyShuffles = if (isSpark40Plus) 0 else 1 Seq("false", "true").foreach { execEnabled => Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => @@ -164,13 +168,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, complexKeyShuffles) } withParquetTable((0 until 50).map(i => (Map(i -> Seq(i, i + 1)), i + 1)), "tbl") { @@ -179,13 +177,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } withParquetTable((0 until 50).map(i => (Map((i, i.toString) -> i), i + 1)), "tbl") { @@ -194,13 +186,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, complexKeyShuffles) } withParquetTable((0 until 50).map(i => (Map(i -> (i, i.toString)), i + 1)), "tbl") { @@ -209,13 +195,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + checkShuffleAnswer(df, 1) } } } @@ -238,13 +218,11 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - if (isSpark40Plus) { - // https://github.com/apache/datafusion-comet/issues/1941 - // Spark 4.0 introduces a mapsort which falls back - checkShuffleAnswer(df, 0) - } else { - checkShuffleAnswer(df, 1) - } + // Spark 4.0 normalizes shuffle keys containing array via + // transform(arr, x -> mapsort(x)), which Comet doesn't yet + // support, so the shuffle falls back to Spark. 
+ val expectedShuffles = if (isSpark40Plus) 0 else 1 + checkShuffleAnswer(df, expectedShuffles) } } } @@ -344,78 +322,54 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar } test("columnar shuffle on map [bool]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(true, false)) } test("columnar shuffle on map [byte]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toByte, 1.toByte)) } test("columnar shuffle on map [short]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toShort, 1.toShort)) } test("columnar shuffle on map [int]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0, 1)) } test("columnar shuffle on map [long]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toLong, 1.toLong)) } test("columnar shuffle on map [float]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toFloat, 1.toFloat)) } test("columnar shuffle on map [double]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toDouble, 1.toDouble)) } test("columnar shuffle on map [date]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(new java.sql.Date(0.toLong), new java.sql.Date(1.toLong))) } test("columnar shuffle on map [timestamp]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest( 50, Seq(new java.sql.Timestamp(0.toLong), new java.sql.Timestamp(1.toLong))) } test("columnar shuffle on map [decimal]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest( 50, Seq(new java.math.BigDecimal(0.toLong), new java.math.BigDecimal(1.toLong))) } test("columnar shuffle on map [string]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toString, 1.toString)) } test("columnar shuffle on map [binary]") { - // https://github.com/apache/datafusion-comet/issues/1941 - assume(!isSpark40Plus) columnarShuffleOnMapTest(50, Seq(0.toString.getBytes(), 1.toString.getBytes())) }