From a081649fc95d96e551e68707a22b2b008f972954 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 8 Feb 2024 14:56:18 +0100 Subject: [PATCH 01/27] initial working version --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 165 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 40 +++++ 3 files changed, 206 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b165d20d0b4fa..18f4eadc0cec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,6 +696,7 @@ object FunctionRegistry { expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), expression[MapConcat]("map_concat"), + expression[SortMap]("sort_map"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a090bdf2bebf6..8088c97ded3b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -888,6 +888,171 @@ case class MapFromEntries(child: Expression) copy(child = newChild) } +@ExpressionDescription( + usage = """ + _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order + according to the natural ordering of the map keys. + """, + examples = """ + Examples: + > SELECT _FUNC_(map(3, 'c', 1, 'a', 2, 'b'), true); + {1:"a",2:"b",3:"c"} + """, + group = "map_funcs", + since = "4.0.0") +case class SortMap(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with NullIntolerant with QueryErrorsBase { + + def this(e: Expression) = this(e, Literal(true)) + + val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType + val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case MapType(kt, _, _) if RowOrdering.isOrderable(kt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "2", + "requiredType" -> toSQLType(BooleanType), + "inputSql" -> toSQLExpr(ascendingOrder), + "inputType" -> toSQLType(ascendingOrder.dataType)) + ) + } + case MapType(_, _, _) => + DataTypeMismatch( + errorSubClass = "INVALID_ORDERING_TYPE", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> toSQLType(base.dataType) + ) + ) + case _ => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(base), + "inputType" -> toSQLType(base.dataType)) + ) + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + // put keys in a tree map and then read them back to build new k/v arrays + + val mapData = array.asInstanceOf[MapData] + val numElements = mapData.numElements() + val keys = mapData.keyArray() + val values = mapData.valueArray() + + val ordering = if (ascending.asInstanceOf[Boolean]) { + PhysicalDataType.ordering(keyType) + } else { + PhysicalDataType.ordering(keyType).reverse + } + + val treeMap = mutable.TreeMap.empty[Any, Int](ordering) + for (i <- 0 until numElements) { + treeMap.put(keys.get(i, keyType), i) + } + + val newKeys = new Array[Any](numElements) + val newValues = new Array[Any](numElements) + + treeMap.zipWithIndex.foreach { case ((_, originalIndex), sortedIndex) => + newKeys(sortedIndex) = keys.get(originalIndex, keyType) + newValues(sortedIndex) = values.get(originalIndex, valueType) + } + + new ArrayBasedMapData(new GenericArrayData(newKeys), new GenericArrayData(newValues)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + + private def sortCodegen(ctx: CodegenContext, ev: ExprCode, + base: String, order: String): String = { + + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val genericArrayData = classOf[GenericArrayData].getName + + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val treeMap = ctx.freshName("treeMap") + val i = ctx.freshName("i") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val c = ctx.freshName("c") + val newKeys = ctx.freshName("newKeys") + val newValues = ctx.freshName("newValues") + val mapEntry = ctx.freshName("mapEntry") + val originalIndex = ctx.freshName("originalIndex") + + val boxedKeyType = CodeGenerator.boxedType(keyType) + val javaKeyType = CodeGenerator.javaType(keyType) + + val comp = if (CodeGenerator.isPrimitiveType(keyType)) { + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value(); + |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value(); + |int $c = ${ctx.genComp(keyType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};" + } + + s""" + |final int $numElements = $base.numElements(); + |ArrayData $keys = $base.keyArray(); + |ArrayData $values = $base.valueArray(); + | + |java.util.TreeMap<$boxedKeyType, Integer> $treeMap = new java.util.TreeMap<>( + | new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | $comp; + | return $order ? $c : -$c; + | } + | } + |); + | + |for (int $i = 0; $i < $numElements; $i++) { + | $treeMap.put(${CodeGenerator.getValue(keys, keyType, i)}, $i); + |} + | + |Object[] $newKeys = new Object[$numElements]; + |Object[] $newValues = new Object[$numElements]; + | + |int $i = 0; + |for (java.util.Map.Entry<$boxedKeyType, Integer> $mapEntry : $treeMap.entrySet()) { + | int $originalIndex = (Integer) $mapEntry.getValue(); + | $newKeys[$i] = ${CodeGenerator.getValue(keys, keyType, originalIndex)}; + | $newValues[$i] = ${CodeGenerator.getValue(values, valueType, originalIndex)}; + | $i++; + |} + | + |${ev.value} = new $arrayBasedMapData( + | new $genericArrayData($newKeys), new $genericArrayData($newValues)); + |""".stripMargin + } + + override def prettyName: String = "sort_map" + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression) + : SortMap = copy(base = newLeft, ascendingOrder = newRight) +} /** * Common base class for [[SortArray]] and [[ArraySort]]. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 133e27c5b0a66..302703d4497a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -421,6 +421,46 @@ class CollectionExpressionsSuite ) } + test("Sort Map") { + val intKey = Literal.create(Map(2 -> 2, 1 -> 1, 3 -> 3), MapType(IntegerType, IntegerType)) + val boolKey = Literal.create(Map(true -> 2, false -> 1), MapType(BooleanType, IntegerType)) + val stringKey = Literal.create(Map("2" -> 2, "1" -> 1, "3" -> 3), + MapType(StringType, IntegerType)) + val arrayKey = Literal.create(Map(Seq(2) -> 2, Seq(1) -> 1, Seq(3) -> 3), + MapType(ArrayType(IntegerType), IntegerType)) + val nestedArrayKey = Literal.create(Map(Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1, Seq(Seq(3)) -> 3), + MapType(ArrayType(ArrayType(IntegerType)), IntegerType)) + val structKey = Literal.create( + Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3), + MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)) + + checkEvaluation(new SortMap(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) + checkEvaluation(SortMap(intKey, Literal.create(false, BooleanType)), + Map(3 -> 3, 2 -> 2, 1 -> 1)) + + checkEvaluation(new SortMap(boolKey), Map(false -> 1, true -> 2)) + checkEvaluation(SortMap(boolKey, Literal.create(false, BooleanType)), + Map(true -> 2, false -> 1)) + + checkEvaluation(new SortMap(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) + checkEvaluation(SortMap(stringKey, Literal.create(false, BooleanType)), + Map("3" -> 3, "2" -> 2, "1" -> 1)) + + checkEvaluation(new SortMap(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) + checkEvaluation(SortMap(arrayKey, Literal.create(false, BooleanType)), + Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1)) + + checkEvaluation(new SortMap(nestedArrayKey), + Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)) + checkEvaluation(SortMap(nestedArrayKey, Literal.create(false, BooleanType)), + Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1)) + + checkEvaluation(new SortMap(structKey), + Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3)) + checkEvaluation(SortMap(structKey, Literal.create(false, BooleanType)), + Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1)) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) From 1441549ed1fbee0188c32a6f3c44cb05d2e3470d Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 11 Feb 2024 01:20:34 +0100 Subject: [PATCH 02/27] add golden files --- .../src/test/resources/sql-functions/sql-expression-schema.md | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index e20db3b49589c..4714e4f70668b 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -301,6 +301,7 @@ | org.apache.spark.sql.catalyst.expressions.Size | size | SELECT size(array('b', 'd', 'c', 'a')) | struct | | org.apache.spark.sql.catalyst.expressions.Slice | slice | SELECT slice(array(1, 2, 3, 4), 2, 2) | struct> | | org.apache.spark.sql.catalyst.expressions.SortArray | sort_array | SELECT sort_array(array('b', 'd', null, 'c', 'a'), true) | struct> | +| org.apache.spark.sql.catalyst.expressions.SortMap | sort_map | SELECT sort_map(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> | | org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct | | org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct | | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct | From 1be06e37e26b27e9f2e66bdc8260cb9d8abf9d81 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 14 Feb 2024 15:27:50 +0100 Subject: [PATCH 03/27] add map sort to other languages --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 +++++++ R/pkg/R/generics.R | 4 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 +++ .../org/apache/spark/sql/functions.scala | 19 ++++++++ .../spark/sql/PlanGenerationTestSuite.scala | 4 ++ .../reference/pyspark.sql/functions.rst | 1 + .../pyspark/sql/connect/functions/builtin.py | 7 +++ python/pyspark/sql/functions/builtin.py | 47 +++++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 7 +++ .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 6 +-- .../CollectionExpressionsSuite.scala | 24 +++++----- .../org/apache/spark/sql/functions.scala | 7 +++ .../sql-functions/sql-expression-schema.md | 2 +- 15 files changed, 137 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c5668d1739b17..bdbcfa552448b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -361,6 +361,7 @@ exportMethods("%<=>%", "map_keys", "map_values", "map_zip_with", + "map_sort", "max", "max_by", "md5", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5106a83bd0ec4..e3452d71682cc 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4523,6 +4523,23 @@ setMethod("map_zip_with", ) }) +#' @details +#' \code{sort_array}: Sorts the input map in ascending or descending order according to +#' the natural ordering of the map keys. +#' +#' @rdname column_collection_functions +#' @param asc a logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @aliases map_sort map_sort,Column-method +#' @note sort_array since 4.0.0 +setMethod("map_sort", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", x@jc, asc) + column(jc) + } + #' @details #' \code{element_at}: Returns element of array at given index in \code{extraction} if #' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 26e81733055a6..2004530da88cb 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1216,6 +1216,10 @@ setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @name NULL setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_sort", function(x, asc = TRUE) { standardGeneric("map_sort") }) + #' @rdname column_aggregate_functions #' @name NULL setGeneric("max_by", function(x, y) { standardGeneric("max_by") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 630781a57e444..652b81d7b7532 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1646,6 +1646,12 @@ test_that("column functions", { expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4))) expect_equal(result, expected_entries) + # Test map_sort + df <- createDataFrame(list(map1 = as.environment(list(c = 3, a = 1, b = 2)))) + result <- collect(select(df, map_concat(df[[1]])))[[1]] + expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) + expect_equal(result, expected_entries) + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_entries(df$map)))[[1]] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 133b7e036cd7c..cad72d7da24aa 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7081,6 +7081,25 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) + /** + * Sorts the input map in ascending order according to the natural ordering + * of the map keys. + * + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column): Column = map_sort(e, asc = true) + + + /** + * Sorts the input map in ascending or descending order according to the natural ordering + * of the map keys. + * + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) + /** * Returns the minimum value in the array. NaN is greater than any non-NaN elements for * double/float type. NULL elements are skipped. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index ee98a1aceea38..6fbee02997275 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -2525,6 +2525,10 @@ class PlanGenerationTestSuite fn.map_from_entries(fn.transform(fn.col("e"), (x, i) => fn.struct(i, x))) } + functionTest("map_sort") { + fn.map_sort(fn.col("f")) + } + functionTest("arrays_zip") { fn.arrays_zip(fn.col("e"), fn.sequence(lit(1), lit(20))) } diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index ca20ccfb73c56..438e1e7a9a88d 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -394,6 +394,7 @@ Map Functions map_from_entries map_keys map_values + map_sort str_to_map diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 72adfec33b1d6..53f3c537cbc41 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2004,6 +2004,13 @@ def map_values(col: "ColumnOrName") -> Column: map_values.__doc__ = pysparkfuncs.map_values.__doc__ +def map_sort(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns("map_sort", col) + + +map_sort.__doc__ = pysparkfuncs.map_sort.__doc__ + + def map_zip_with( col1: "ColumnOrName", col2: "ColumnOrName", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 6320f9b922eef..226cca3f87f7b 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16839,6 +16839,53 @@ def map_concat( cols = cols[0] # type: ignore[assignment] return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] +@_try_remote_functions +def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: + """ + Map function: Sorts the input map in ascending or descending order according + to the natural ordering of the map keys. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + Name of the column or expression. + asc : bool, optional + Whether to sort in ascending or descending order. If `asc` is True (default), + then the sorting is in ascending order. If False, then in descending order. + + Returns + ------- + :class:`~pyspark.sql.Column` + Sorted map. + + Examples + -------- + Example 1: Sorting a map in ascending order + + >>> import pyspark.sql.functions as sf + >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") + >>> df.select(sf.map_sort(df.data)).show() + +------------------------+ + | map_sort(data, true)| + +------------------------+ + |{1 -> a, 2 -> b, 3 -> c}| + +------------------------+ + + Example 2: Sorting a map in descending order + + >>> import pyspark.sql.functions as sf + >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") + >>> df.select(sf.map_sort(df.data, false)).show() + +------------------------+ + | map_sort(data, true)| + +------------------------+ + |{3 -> c, 2 -> b, 1 -> a}| + +------------------------+ + """ + return _invoke_function("map_sort", _to_java_column(col), asc) + @_try_remote_functions def sequence( diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a736832c8ef99..74e8a5f2a90e1 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1440,6 +1440,13 @@ def test_map_concat(self): {1: "a", 2: "b", 3: "c"}, ) + def test_map_sort(self): + df = self.spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as map1") + self.assertEqual( + df.select(F.map_sort("map1").alias("map2")).first()[0], + {1: "a", 2: "b", 3: "c"}, + ) + def test_version(self): self.assertIsInstance(self.spark.range(1).select(F.version()).first()[0], str) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 18f4eadc0cec0..f64f88cfd9b65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,7 +696,7 @@ object FunctionRegistry { expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), expression[MapConcat]("map_concat"), - expression[SortMap]("sort_map"), + expression[MapSort]("map_sort"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8088c97ded3b7..fae74bb1580ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -900,7 +900,7 @@ case class MapFromEntries(child: Expression) """, group = "map_funcs", since = "4.0.0") -case class SortMap(base: Expression, ascendingOrder: Expression) +case class MapSort(base: Expression, ascendingOrder: Expression) extends BinaryExpression with NullIntolerant with QueryErrorsBase { def this(e: Expression) = this(e, Literal(true)) @@ -1048,10 +1048,10 @@ case class SortMap(base: Expression, ascendingOrder: Expression) |""".stripMargin } - override def prettyName: String = "sort_map" + override def prettyName: String = "map_sort" override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression) - : SortMap = copy(base = newLeft, ascendingOrder = newRight) + : MapSort = copy(base = newLeft, ascendingOrder = newRight) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 302703d4497a2..3063b83d4dca1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -434,30 +434,30 @@ class CollectionExpressionsSuite Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3), MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)) - checkEvaluation(new SortMap(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) - checkEvaluation(SortMap(intKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) + checkEvaluation(MapSort(intKey, Literal.create(false, BooleanType)), Map(3 -> 3, 2 -> 2, 1 -> 1)) - checkEvaluation(new SortMap(boolKey), Map(false -> 1, true -> 2)) - checkEvaluation(SortMap(boolKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(boolKey), Map(false -> 1, true -> 2)) + checkEvaluation(MapSort(boolKey, Literal.create(false, BooleanType)), Map(true -> 2, false -> 1)) - checkEvaluation(new SortMap(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) - checkEvaluation(SortMap(stringKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) + checkEvaluation(MapSort(stringKey, Literal.create(false, BooleanType)), Map("3" -> 3, "2" -> 2, "1" -> 1)) - checkEvaluation(new SortMap(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) - checkEvaluation(SortMap(arrayKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) + checkEvaluation(MapSort(arrayKey, Literal.create(false, BooleanType)), Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1)) - checkEvaluation(new SortMap(nestedArrayKey), + checkEvaluation(new MapSort(nestedArrayKey), Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)) - checkEvaluation(SortMap(nestedArrayKey, Literal.create(false, BooleanType)), + checkEvaluation(MapSort(nestedArrayKey, Literal.create(false, BooleanType)), Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1)) - checkEvaluation(new SortMap(structKey), + checkEvaluation(new MapSort(structKey), Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3)) - checkEvaluation(SortMap(structKey, Literal.create(false, BooleanType)), + checkEvaluation(MapSort(structKey, Literal.create(false, BooleanType)), Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 933d0b3f89a7e..cd3d182841278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6986,6 +6986,13 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = Column.fn("map_concat", cols: _*) + /** + * Sorts the input map in ascending order based on the natural order of map keys. + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column): Column = Column.fn("map_sort", e) + // scalastyle:off line.size.limit /** * Parses a column containing a CSV string into a `StructType` with the specified schema. diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 4714e4f70668b..999cd68738484 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -215,6 +215,7 @@ | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> | | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> | +| org.apache.spark.sql.catalyst.expressions.MapSort | map_sort | SELECT map_sort(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> | | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> | | org.apache.spark.sql.catalyst.expressions.MaskExpressionBuilder | mask | SELECT mask('abcd-EFGH-8765-4321') | struct | @@ -301,7 +302,6 @@ | org.apache.spark.sql.catalyst.expressions.Size | size | SELECT size(array('b', 'd', 'c', 'a')) | struct | | org.apache.spark.sql.catalyst.expressions.Slice | slice | SELECT slice(array(1, 2, 3, 4), 2, 2) | struct> | | org.apache.spark.sql.catalyst.expressions.SortArray | sort_array | SELECT sort_array(array('b', 'd', null, 'c', 'a'), true) | struct> | -| org.apache.spark.sql.catalyst.expressions.SortMap | sort_map | SELECT sort_map(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> | | org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct | | org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct | | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct | From 249e903d596d3803d4c8bfbbe9b6ecce7b29042c Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 28 Feb 2024 15:35:51 +0100 Subject: [PATCH 04/27] fix typoes --- R/pkg/R/functions.R | 2 +- python/pyspark/sql/functions/builtin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e3452d71682cc..69ea77b87b2e0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4538,7 +4538,7 @@ setMethod("map_sort", function(x, asc = TRUE) { jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", x@jc, asc) column(jc) - } + }) #' @details #' \code{element_at}: Returns element of array at given index in \code{extraction} if diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 226cca3f87f7b..cc6cf3f0126be 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16877,7 +16877,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data, false)).show() + >>> df.select(sf.map_sort(df.data, False)).show() +------------------------+ | map_sort(data, true)| +------------------------+ From aaae8835463d25bf3d04a014115ae6177ce2e0c6 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 28 Feb 2024 15:43:24 +0100 Subject: [PATCH 05/27] fix scalastyle issue --- .../src/main/scala/org/apache/spark/sql/functions.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index cad72d7da24aa..15d8f4253eb92 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7082,18 +7082,16 @@ object functions { def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) /** - * Sorts the input map in ascending order according to the natural ordering - * of the map keys. + * Sorts the input map in ascending order according to the natural ordering of the map keys. * * @group map_funcs * @since 4.0.0 */ def map_sort(e: Column): Column = map_sort(e, asc = true) - /** - * Sorts the input map in ascending or descending order according to the natural ordering - * of the map keys. + * Sorts the input map in ascending or descending order according to the natural ordering of the + * map keys. * * @group map_funcs * @since 4.0.0 From acaf95e3cdd767cd0708c6fe16e416c13ef6600c Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 28 Feb 2024 15:53:00 +0100 Subject: [PATCH 06/27] add proto golden files --- .../queries/function_map_sort.json | 29 ++++++++++++++++++ .../queries/function_map_sort.proto.bin | Bin 0 -> 183 bytes 2 files changed, 29 insertions(+) create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json new file mode 100644 index 0000000000000..81a9788d0fbae --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "map_sort", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "f" + } + }, { + "literal": { + "boolean": true + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..57b823a5712988205bd9b2b37ce7f274fe5cdf62 GIT binary patch literal 183 zcmd;L5@3|tz{oX;k&8)yA*!2EsDrV%q^LBx#3nPvDk(EPGp|G^(F#N+S*7HcCgr5+ zq*xJ9VW*R7l~`1iSZM>)XQz{9m77>#1Jsk5m##xdtDR0d$atVqJ1I#iaV`#^-uUAD Tq7oriA!aVdG$9r)CJ9CWE( Date: Wed, 28 Feb 2024 17:26:05 +0100 Subject: [PATCH 07/27] fix python function call --- .../pyspark/sql/connect/functions/builtin.py | 4 +-- .../org/apache/spark/sql/functions.scala | 13 +++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 25 +++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 53f3c537cbc41..318e6f7887699 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2004,8 +2004,8 @@ def map_values(col: "ColumnOrName") -> Column: map_values.__doc__ = pysparkfuncs.map_values.__doc__ -def map_sort(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("map_sort", col) +def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: + return _invoke_function("map_sort", _to_col(col), lit(asc)) map_sort.__doc__ = pysparkfuncs.map_sort.__doc__ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cd3d182841278..2bc6db58333e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6988,10 +6988,21 @@ object functions { /** * Sorts the input map in ascending order based on the natural order of map keys. + * + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column): Column = map_sort(e, asc = true) + // TODO: add test for this + + /** + * Sorts the input map in ascending or descending order according to the natural ordering + * of the map keys. + * * @group map_funcs * @since 4.0.0 */ - def map_sort(e: Column): Column = Column.fn("map_sort", e) + def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) // scalastyle:off line.size.limit /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e42f397cbfc29..cac0107e2443b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -780,6 +780,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("map_sort function") { + val df1 = Seq( + Map[Int, Int](2 -> 2, 1 -> 1, 3 -> 3) + ).toDF("a") + + checkAnswer( + df1.selectExpr("map_sort(a)"), + Seq( + Row(Map(1 -> 1, 2 -> 2, 3 -> 3)) + ) + ) + checkAnswer( + df1.selectExpr("map_sort(a, true)"), + Seq( + Row(Map(1 -> 1, 2 -> 2, 3 -> 3)) + ) + ) + checkAnswer( + df1.select(map_sort($"a", asc = false)), + Seq( + Row(Map(3 -> 3, 2 -> 2, 1 -> 1)) + ) + ) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From 7754c14c4deb54acba8e75f76371d5f56f8795f7 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 29 Feb 2024 09:48:23 +0100 Subject: [PATCH 08/27] fix ci errors --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- .../query-tests/explain-results/function_map_sort.explain | 2 ++ python/pyspark/sql/functions/builtin.py | 7 ++++--- 3 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 652b81d7b7532..fa87106c1f144 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1648,7 +1648,7 @@ test_that("column functions", { # Test map_sort df <- createDataFrame(list(map1 = as.environment(list(c = 3, a = 1, b = 2)))) - result <- collect(select(df, map_concat(df[[1]])))[[1]] + result <- collect(select(df, map_sort(df[[1]])))[[1]] expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) expect_equal(result, expected_entries) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain new file mode 100644 index 0000000000000..069b2ce65d187 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain @@ -0,0 +1,2 @@ +Project [map_sort(f#0, true) AS map_sort(f, true)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index cc6cf3f0126be..61bfa4db79d44 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16839,6 +16839,7 @@ def map_concat( cols = cols[0] # type: ignore[assignment] return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] + @_try_remote_functions def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: """ @@ -16870,7 +16871,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: +------------------------+ | map_sort(data, true)| +------------------------+ - |{1 -> a, 2 -> b, 3 -> c}| + | {1 -> a, 2 -> b, ...| +------------------------+ Example 2: Sorting a map in descending order @@ -16879,9 +16880,9 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data, False)).show() +------------------------+ - | map_sort(data, true)| + | map_sort(data, false)| +------------------------+ - |{3 -> c, 2 -> b, 1 -> a}| + | {3 -> c, 2 -> b, ...| +------------------------+ """ return _invoke_function("map_sort", _to_java_column(col), asc) From f0ebf5dc5a4d9d118babb6865e7d7871d2f44d0b Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 29 Feb 2024 14:09:26 +0100 Subject: [PATCH 09/27] fix ci checks --- R/pkg/R/functions.R | 4 ++-- python/pyspark/sql/functions/builtin.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 69ea77b87b2e0..143277eab1417 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4524,7 +4524,7 @@ setMethod("map_zip_with", }) #' @details -#' \code{sort_array}: Sorts the input map in ascending or descending order according to +#' \code{map_sort}: Sorts the input map in ascending or descending order according to #' the natural ordering of the map keys. #' #' @rdname column_collection_functions @@ -4532,7 +4532,7 @@ setMethod("map_zip_with", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases map_sort map_sort,Column-method -#' @note sort_array since 4.0.0 +#' @note map_sort since 4.0.0 setMethod("map_sort", signature(x = "Column"), function(x, asc = TRUE) { diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 61bfa4db79d44..0167f7fd2be93 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16868,22 +16868,22 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data)).show() - +------------------------+ - | map_sort(data, true)| - +------------------------+ - | {1 -> a, 2 -> b, ...| - +------------------------+ + +--------------------+ + |map_sort(data, true)| + +--------------------+ + |{1 -> a, 2 -> b, ...| + +--------------------+ Example 2: Sorting a map in descending order >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data, False)).show() - +------------------------+ - | map_sort(data, false)| - +------------------------+ - | {3 -> c, 2 -> b, ...| - +------------------------+ + +---------------------+ + |map_sort(data, false)| + +---------------------+ + | {3 -> c, 2 -> b, ...| + +---------------------+ """ return _invoke_function("map_sort", _to_java_column(col), asc) From 1f78167886a4cc2dee132732a0678d6d503967a0 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 12 Mar 2024 17:44:38 +0100 Subject: [PATCH 10/27] Optimized map-sort by switching to array sorting --- .../expressions/collectionOperations.scala | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fae74bb1580ac..b095fe483f25a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -947,7 +947,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) } override def nullSafeEval(array: Any, ascending: Any): Any = { - // put keys in a tree map and then read them back to build new k/v arrays + // put keys and their respective indices inside a tuple + // and sort them to extract new order k/v pairs val mapData = array.asInstanceOf[MapData] val numElements = mapData.numElements() @@ -960,17 +961,16 @@ case class MapSort(base: Expression, ascendingOrder: Expression) PhysicalDataType.ordering(keyType).reverse } - val treeMap = mutable.TreeMap.empty[Any, Int](ordering) - for (i <- 0 until numElements) { - treeMap.put(keys.get(i, keyType), i) - } + val sortedKeys = Array + .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], i)) + .sortBy(_._1)(ordering) val newKeys = new Array[Any](numElements) val newValues = new Array[Any](numElements) - treeMap.zipWithIndex.foreach { case ((_, originalIndex), sortedIndex) => - newKeys(sortedIndex) = keys.get(originalIndex, keyType) - newValues(sortedIndex) = values.get(originalIndex, valueType) + sortedKeys.zipWithIndex.foreach { case (elem, index) => + newKeys(index) = keys.get(elem._2, keyType) + newValues(index) = values.get(elem._2, valueType) } new ArrayBasedMapData(new GenericArrayData(newKeys), new GenericArrayData(newValues)) @@ -989,19 +989,22 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val numElements = ctx.freshName("numElements") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val treeMap = ctx.freshName("treeMap") + val sortArray = ctx.freshName("sortArray") val i = ctx.freshName("i") val o1 = ctx.freshName("o1") + val o1entry = ctx.freshName("o1entry") val o2 = ctx.freshName("o2") + val o2entry = ctx.freshName("o2entry") val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") - val mapEntry = ctx.freshName("mapEntry") val originalIndex = ctx.freshName("originalIndex") val boxedKeyType = CodeGenerator.boxedType(keyType) val javaKeyType = CodeGenerator.javaType(keyType) + val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, Integer>" + val comp = if (CodeGenerator.isPrimitiveType(keyType)) { val v1 = ctx.freshName("v1") val v2 = ctx.freshName("v2") @@ -1019,28 +1022,29 @@ case class MapSort(base: Expression, ascendingOrder: Expression) |ArrayData $keys = $base.keyArray(); |ArrayData $values = $base.valueArray(); | - |java.util.TreeMap<$boxedKeyType, Integer> $treeMap = new java.util.TreeMap<>( - | new java.util.Comparator() { - | @Override public int compare(Object $o1, Object $o2) { - | $comp; - | return $order ? $c : -$c; - | } - | } - |); + |Object[] $sortArray = new Object[$numElements]; | |for (int $i = 0; $i < $numElements; $i++) { - | $treeMap.put(${CodeGenerator.getValue(keys, keyType, i)}, $i); + | $sortArray[$i] = new $simpleEntryType( + | ${CodeGenerator.getValue(keys, keyType, i)}, $i); |} | + |java.util.Arrays.sort($sortArray, new java.util.Comparator() { + | @Override public int compare(Object $o1entry, Object $o2entry) { + | Object $o1 = (($simpleEntryType) $o1entry).getKey(); + | Object $o2 = (($simpleEntryType) $o2entry).getKey(); + | $comp; + | return $order ? $c : -$c; + | } + |}); + | |Object[] $newKeys = new Object[$numElements]; |Object[] $newValues = new Object[$numElements]; | - |int $i = 0; - |for (java.util.Map.Entry<$boxedKeyType, Integer> $mapEntry : $treeMap.entrySet()) { - | int $originalIndex = (Integer) $mapEntry.getValue(); + |for (int $i = 0; $i < $numElements; $i++) { + | int $originalIndex = (Integer) ((($simpleEntryType) $sortArray[$i]).getValue()); | $newKeys[$i] = ${CodeGenerator.getValue(keys, keyType, originalIndex)}; | $newValues[$i] = ${CodeGenerator.getValue(values, valueType, originalIndex)}; - | $i++; |} | |${ev.value} = new $arrayBasedMapData( From a5eb4807903f6dc8fbbc972b11359956044bb761 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Wed, 13 Mar 2024 10:17:36 +0100 Subject: [PATCH 11/27] Potential tests fix --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index fa87106c1f144..ca3353f4eb899 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1647,7 +1647,7 @@ test_that("column functions", { expect_equal(result, expected_entries) # Test map_sort - df <- createDataFrame(list(map1 = as.environment(list(c = 3, a = 1, b = 2)))) + df <- createDataFrame(list(List(map1 = as.environment(list(c = 3, a = 1, b = 2))))) result <- collect(select(df, map_sort(df[[1]])))[[1]] expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) expect_equal(result, expected_entries) From 9497f998b3ebb14a7cfc910100524af640530305 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Wed, 13 Mar 2024 10:53:15 +0100 Subject: [PATCH 12/27] Potential tests fix 2 --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index ca3353f4eb899..75fe342d4d487 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1647,7 +1647,7 @@ test_that("column functions", { expect_equal(result, expected_entries) # Test map_sort - df <- createDataFrame(list(List(map1 = as.environment(list(c = 3, a = 1, b = 2))))) + df <- createDataFrame(list(list(map1 = as.environment(list(c = 3, a = 1, b = 2))))) result <- collect(select(df, map_sort(df[[1]])))[[1]] expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) expect_equal(result, expected_entries) From 5e7a033c0d02d4609f728f35bd6f2cc7245c82fe Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Sun, 17 Mar 2024 12:25:56 +0100 Subject: [PATCH 13/27] Removed TODOs and changed parmIndex to ordinal --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/functions.scala | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b095fe483f25a..f6e05ba624a67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -921,7 +921,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "2", + "paramIndex" -> ordinalNumber(1), "requiredType" -> toSQLType(BooleanType), "inputSql" -> toSQLExpr(ascendingOrder), "inputType" -> toSQLType(ascendingOrder.dataType)) @@ -939,8 +939,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", - "requiredType" -> toSQLType(ArrayType), + "paramIndex" -> ordinalNumber(0), + "requiredType" -> toSQLType(MapType), "inputSql" -> toSQLExpr(base), "inputType" -> toSQLType(base.dataType)) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2bc6db58333e3..dc987e1083dc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6993,7 +6993,6 @@ object functions { * @since 4.0.0 */ def map_sort(e: Column): Column = map_sort(e, asc = true) - // TODO: add test for this /** * Sorts the input map in ascending or descending order according to the natural ordering From ab70f1e44e1672f08d5ae42cd5cd8c33c5ea1f7f Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 10:36:05 +0100 Subject: [PATCH 14/27] Shortened map sort function and added more docs --- .../expressions/collectionOperations.scala | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f6e05ba624a67..f02f444ec7f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -891,7 +891,16 @@ case class MapFromEntries(child: Expression) @ExpressionDescription( usage = """ _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order - according to the natural ordering of the map keys. + according to the natural ordering of the map keys. The sorting algorithm used is + an adaptive, stable and iterative merge sort algorithm. If the input map is empty, + function returns an empty map. + """, + arguments = + """ + Arguments: + * map - an expression. The map that will be sorted. + * ascendingOrder - an expression. The ordering in which the map will be sorted. + This can be either ascending or descending element order. """, examples = """ Examples: @@ -961,19 +970,13 @@ case class MapSort(base: Expression, ascendingOrder: Expression) PhysicalDataType.ordering(keyType).reverse } - val sortedKeys = Array - .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], i)) + val sortedMap = Array + .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], + values.get(i, valueType).asInstanceOf[Any])) .sortBy(_._1)(ordering) - val newKeys = new Array[Any](numElements) - val newValues = new Array[Any](numElements) - - sortedKeys.zipWithIndex.foreach { case (elem, index) => - newKeys(index) = keys.get(elem._2, keyType) - newValues(index) = values.get(elem._2, valueType) - } - - new ArrayBasedMapData(new GenericArrayData(newKeys), new GenericArrayData(newValues)) + new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)), + new GenericArrayData(sortedMap.map(_._2))) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { From e79d65cbba4087d166cc1ec859f7ba6dc01e9aef Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 13:38:03 +0100 Subject: [PATCH 15/27] updated map_sort test suite --- .../spark/sql/DataFrameFunctionsSuite.scala | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cac0107e2443b..6034bd5cc9cb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -25,7 +25,7 @@ import java.sql.{Date, Timestamp} import scala.util.Random import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.Cast._ @@ -803,6 +803,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Map(3 -> 3, 2 -> 2, 1 -> 1)) ) ) + + val df2 = Seq(Map.empty[Int, Int]).toDF("a") + + checkAnswer( + df2.selectExpr("map_sort(a, true)"), + Seq(Row(Map())) + ) + + checkError( + exception = intercept[AnalysisException] { + df2.orderBy("a") + }, + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + parameters = Map( + "functionName" -> "`sortorder`", + "dataType" -> "\"MAP\"", + "sqlExpr" -> "\"a ASC NULLS FIRST\"") + ) + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT map_sort(map(null, 1))").collect() + }, + errorClass = "NULL_MAP_KEY" + ) + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT map_sort(map(1, 1, 2, 2, 1, 1))").collect() + }, + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "1", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"" + ) + ) + + checkError( + exception = intercept[ExtendedAnalysisException] { + sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", + "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"BOOLEAN\"" + ), + queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")")) + ) + + checkError( + exception = intercept[ExtendedAnalysisException] { + sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", + "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"BOOLEAN\"" + ), + queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")")) + ) } test("sort_array/array_sort functions") { From a43535539ba8b3f5a81aa6200ffc791964bd3f33 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 16:08:19 +0100 Subject: [PATCH 16/27] Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala Co-authored-by: Maxim Gekk --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6034bd5cc9cb2..67ebe7a89e124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -847,7 +847,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", - "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "paramIndex" -> "second", + "inputSql" -> "\"asc\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"BOOLEAN\"" ), From c9901d08f83cc60961993fe3a64acebba168ea89 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 16:08:35 +0100 Subject: [PATCH 17/27] Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala Co-authored-by: Maxim Gekk --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 67ebe7a89e124..54e1ec53c25c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -862,7 +862,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", - "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "paramIndex" -> "second", + "inputSql" -> "\"asc\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"BOOLEAN\"" ), From da6a710b7ddb56068562ddccfddc6329ec16f4d9 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 16:27:48 +0100 Subject: [PATCH 18/27] docs fix --- .../catalyst/expressions/collectionOperations.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f02f444ec7f7d..6d1e72b5970e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -891,16 +891,16 @@ case class MapFromEntries(child: Expression) @ExpressionDescription( usage = """ _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order - according to the natural ordering of the map keys. The sorting algorithm used is - an adaptive, stable and iterative merge sort algorithm. If the input map is empty, - function returns an empty map. + according to the natural ordering of the map keys. The algorithm used for sorting is + an adaptive, stable and iterative algorithm. If the input map is empty, function + returns an empty map. """, arguments = """ Arguments: - * map - an expression. The map that will be sorted. - * ascendingOrder - an expression. The ordering in which the map will be sorted. - This can be either ascending or descending element order. + * map - The map that will be sorted. + * ascendingOrder - A boolean value describing the order in which the map will be sorted. + This can be either be ascending (true) or descending (false). """, examples = """ Examples: From 81008c218917e2b590111b132ed0b7008bedf305 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 09:26:16 +0100 Subject: [PATCH 19/27] Updated codegen and removed once test-case --- .../expressions/collectionOperations.scala | 15 ++++++++------- .../spark/sql/DataFrameFunctionsSuite.scala | 11 ----------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6d1e72b5970e3..896620eec9ac8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -956,8 +956,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) } override def nullSafeEval(array: Any, ascending: Any): Any = { - // put keys and their respective indices inside a tuple - // and sort them to extract new order k/v pairs + // put keys and their respective values inside a tuple and sort them + // according to the key ordering. Extract the new sorted k/v pairs to form a sorted map val mapData = array.asInstanceOf[MapData] val numElements = mapData.numElements() @@ -1004,9 +1004,10 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val originalIndex = ctx.freshName("originalIndex") val boxedKeyType = CodeGenerator.boxedType(keyType) + val boxedValueType = CodeGenerator.boxedType(valueType) val javaKeyType = CodeGenerator.javaType(keyType) - val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, Integer>" + val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>" val comp = if (CodeGenerator.isPrimitiveType(keyType)) { val v1 = ctx.freshName("v1") @@ -1029,7 +1030,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) | |for (int $i = 0; $i < $numElements; $i++) { | $sortArray[$i] = new $simpleEntryType( - | ${CodeGenerator.getValue(keys, keyType, i)}, $i); + | ${CodeGenerator.getValue(keys, keyType, i)}, + | ${CodeGenerator.getValue(values, valueType, i)}); |} | |java.util.Arrays.sort($sortArray, new java.util.Comparator() { @@ -1045,9 +1047,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) |Object[] $newValues = new Object[$numElements]; | |for (int $i = 0; $i < $numElements; $i++) { - | int $originalIndex = (Integer) ((($simpleEntryType) $sortArray[$i]).getValue()); - | $newKeys[$i] = ${CodeGenerator.getValue(keys, keyType, originalIndex)}; - | $newValues[$i] = ${CodeGenerator.getValue(values, valueType, originalIndex)}; + | $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey(); + | $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue(); |} | |${ev.value} = new $arrayBasedMapData( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 54e1ec53c25c6..e5953e59a51b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -829,17 +829,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "NULL_MAP_KEY" ) - checkError( - exception = intercept[SparkRuntimeException] { - sql("SELECT map_sort(map(1, 1, 2, 2, 1, 1))").collect() - }, - errorClass = "DUPLICATED_MAP_KEY", - parameters = Map( - "key" -> "1", - "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"" - ) - ) - checkError( exception = intercept[ExtendedAnalysisException] { sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() From 86b29c5ca6171529f18e611cda899f301c5aee0c Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 09:33:26 +0100 Subject: [PATCH 20/27] Update python/pyspark/sql/functions/builtin.py Co-authored-by: Ruifeng Zheng --- python/pyspark/sql/functions/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0167f7fd2be93..0832f73785cd3 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16867,7 +16867,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data)).show() + >>> df.select(sf.map_sort(df.data)).show(truncate=False) +--------------------+ |map_sort(data, true)| +--------------------+ From c08ab6c027f3eaf4e6240d92c6d946909d0e570c Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 09:35:21 +0100 Subject: [PATCH 21/27] Updated 'select.show' to give more info in map_sort desc --- python/pyspark/sql/functions/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0832f73785cd3..d206197996a94 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16878,7 +16878,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data, False)).show() + >>> df.select(sf.map_sort(df.data, False)).show(truncate=False) +---------------------+ |map_sort(data, false)| +---------------------+ From 31a797c34925fa651d0d911ce169335dcd75f4c2 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 21:07:56 +0100 Subject: [PATCH 22/27] Restructured docs, removed unused variable and refactored code --- python/pyspark/sql/functions/builtin.py | 20 +++++++++---------- .../expressions/collectionOperations.scala | 5 ++--- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index d206197996a94..8710a7c6bb306 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16868,22 +16868,22 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data)).show(truncate=False) - +--------------------+ - |map_sort(data, true)| - +--------------------+ - |{1 -> a, 2 -> b, ...| - +--------------------+ + +------------------------+ + |map_sort(data, true) | + +------------------------+ + |{1 -> a, 2 -> b, 3 -> c}| + +------------------------+ Example 2: Sorting a map in descending order >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data, False)).show(truncate=False) - +---------------------+ - |map_sort(data, false)| - +---------------------+ - | {3 -> c, 2 -> b, ...| - +---------------------+ + +------------------------+ + |map_sort(data, false) | + +------------------------+ + |{3 -> c, 2 -> b, 1 -> a}| + +------------------------+ """ return _invoke_function("map_sort", _to_java_column(col), asc) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 896620eec9ac8..3ed711d477621 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -922,7 +922,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) override def dataType: DataType = base.dataType override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case MapType(kt, _, _) if RowOrdering.isOrderable(kt) => + case m: MapType if RowOrdering.isOrderable(m.keyType) => ascendingOrder match { case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess @@ -936,7 +936,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) "inputType" -> toSQLType(ascendingOrder.dataType)) ) } - case MapType(_, _, _) => + case _: MapType => DataTypeMismatch( errorSubClass = "INVALID_ORDERING_TYPE", messageParameters = Map( @@ -1001,7 +1001,6 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") - val originalIndex = ctx.freshName("originalIndex") val boxedKeyType = CodeGenerator.boxedType(keyType) val boxedValueType = CodeGenerator.boxedType(valueType) From 69e3b48f7a8a539a3d9a1c968d5ff2e33e4b367c Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 12:20:48 +0100 Subject: [PATCH 23/27] Removed map_sort function but left the MapSort expression --- R/pkg/NAMESPACE | 1 - R/pkg/R/functions.R | 17 ------- R/pkg/tests/fulltests/test_sparkSQL.R | 6 --- .../org/apache/spark/sql/functions.scala | 17 ------- .../spark/sql/PlanGenerationTestSuite.scala | 4 -- .../explain-results/function_map_sort.explain | 2 - .../queries/function_map_sort.json | 29 ----------- .../queries/function_map_sort.proto.bin | Bin 183 -> 0 bytes .../reference/pyspark.sql/functions.rst | 1 - .../pyspark/sql/connect/functions/builtin.py | 7 --- python/pyspark/sql/functions/builtin.py | 48 ------------------ python/pyspark/sql/tests/test_functions.py | 7 --- .../catalyst/analysis/FunctionRegistry.scala | 1 - .../org/apache/spark/sql/functions.scala | 17 ------- .../sql-functions/sql-expression-schema.md | 1 - 15 files changed, 158 deletions(-) delete mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index bdbcfa552448b..c5668d1739b17 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -361,7 +361,6 @@ exportMethods("%<=>%", "map_keys", "map_values", "map_zip_with", - "map_sort", "max", "max_by", "md5", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 143277eab1417..5106a83bd0ec4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4523,23 +4523,6 @@ setMethod("map_zip_with", ) }) -#' @details -#' \code{map_sort}: Sorts the input map in ascending or descending order according to -#' the natural ordering of the map keys. -#' -#' @rdname column_collection_functions -#' @param asc a logical flag indicating the sorting order. -#' TRUE, sorting is in ascending order. -#' FALSE, sorting is in descending order. -#' @aliases map_sort map_sort,Column-method -#' @note map_sort since 4.0.0 -setMethod("map_sort", - signature(x = "Column"), - function(x, asc = TRUE) { - jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", x@jc, asc) - column(jc) - }) - #' @details #' \code{element_at}: Returns element of array at given index in \code{extraction} if #' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 75fe342d4d487..630781a57e444 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1646,12 +1646,6 @@ test_that("column functions", { expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4))) expect_equal(result, expected_entries) - # Test map_sort - df <- createDataFrame(list(list(map1 = as.environment(list(c = 3, a = 1, b = 2))))) - result <- collect(select(df, map_sort(df[[1]])))[[1]] - expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) - expect_equal(result, expected_entries) - # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_entries(df$map)))[[1]] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 15d8f4253eb92..133b7e036cd7c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7081,23 +7081,6 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) - /** - * Sorts the input map in ascending order according to the natural ordering of the map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column): Column = map_sort(e, asc = true) - - /** - * Sorts the input map in ascending or descending order according to the natural ordering of the - * map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) - /** * Returns the minimum value in the array. NaN is greater than any non-NaN elements for * double/float type. NULL elements are skipped. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 6fbee02997275..ee98a1aceea38 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -2525,10 +2525,6 @@ class PlanGenerationTestSuite fn.map_from_entries(fn.transform(fn.col("e"), (x, i) => fn.struct(i, x))) } - functionTest("map_sort") { - fn.map_sort(fn.col("f")) - } - functionTest("arrays_zip") { fn.arrays_zip(fn.col("e"), fn.sequence(lit(1), lit(20))) } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain deleted file mode 100644 index 069b2ce65d187..0000000000000 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain +++ /dev/null @@ -1,2 +0,0 @@ -Project [map_sort(f#0, true) AS map_sort(f, true)#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json deleted file mode 100644 index 81a9788d0fbae..0000000000000 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "common": { - "planId": "1" - }, - "project": { - "input": { - "common": { - "planId": "0" - }, - "localRelation": { - "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" - } - }, - "expressions": [{ - "unresolvedFunction": { - "functionName": "map_sort", - "arguments": [{ - "unresolvedAttribute": { - "unparsedIdentifier": "f" - } - }, { - "literal": { - "boolean": true - } - }] - } - }] - } -} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin deleted file mode 100644 index 57b823a5712988205bd9b2b37ce7f274fe5cdf62..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 183 zcmd;L5@3|tz{oX;k&8)yA*!2EsDrV%q^LBx#3nPvDk(EPGp|G^(F#N+S*7HcCgr5+ zq*xJ9VW*R7l~`1iSZM>)XQz{9m77>#1Jsk5m##xdtDR0d$atVqJ1I#iaV`#^-uUAD Tq7oriA!aVdG$9r)CJ9CWE( Column: map_values.__doc__ = pysparkfuncs.map_values.__doc__ -def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: - return _invoke_function("map_sort", _to_col(col), lit(asc)) - - -map_sort.__doc__ = pysparkfuncs.map_sort.__doc__ - - def map_zip_with( col1: "ColumnOrName", col2: "ColumnOrName", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 8710a7c6bb306..6320f9b922eef 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16840,54 +16840,6 @@ def map_concat( return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] -@_try_remote_functions -def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: - """ - Map function: Sorts the input map in ascending or descending order according - to the natural ordering of the map keys. - - .. versionadded:: 4.0.0 - - Parameters - ---------- - col : :class:`~pyspark.sql.Column` or str - Name of the column or expression. - asc : bool, optional - Whether to sort in ascending or descending order. If `asc` is True (default), - then the sorting is in ascending order. If False, then in descending order. - - Returns - ------- - :class:`~pyspark.sql.Column` - Sorted map. - - Examples - -------- - Example 1: Sorting a map in ascending order - - >>> import pyspark.sql.functions as sf - >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data)).show(truncate=False) - +------------------------+ - |map_sort(data, true) | - +------------------------+ - |{1 -> a, 2 -> b, 3 -> c}| - +------------------------+ - - Example 2: Sorting a map in descending order - - >>> import pyspark.sql.functions as sf - >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data, False)).show(truncate=False) - +------------------------+ - |map_sort(data, false) | - +------------------------+ - |{3 -> c, 2 -> b, 1 -> a}| - +------------------------+ - """ - return _invoke_function("map_sort", _to_java_column(col), asc) - - @_try_remote_functions def sequence( start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 74e8a5f2a90e1..a736832c8ef99 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1440,13 +1440,6 @@ def test_map_concat(self): {1: "a", 2: "b", 3: "c"}, ) - def test_map_sort(self): - df = self.spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as map1") - self.assertEqual( - df.select(F.map_sort("map1").alias("map2")).first()[0], - {1: "a", 2: "b", 3: "c"}, - ) - def test_version(self): self.assertIsInstance(self.spark.range(1).select(F.version()).first()[0], str) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f64f88cfd9b65..b165d20d0b4fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,7 +696,6 @@ object FunctionRegistry { expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), expression[MapConcat]("map_concat"), - expression[MapSort]("map_sort"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality", true), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dc987e1083dc8..933d0b3f89a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6986,23 +6986,6 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = Column.fn("map_concat", cols: _*) - /** - * Sorts the input map in ascending order based on the natural order of map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column): Column = map_sort(e, asc = true) - - /** - * Sorts the input map in ascending or descending order according to the natural ordering - * of the map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) - // scalastyle:off line.size.limit /** * Parses a column containing a CSV string into a `StructType` with the specified schema. diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 999cd68738484..e20db3b49589c 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -215,7 +215,6 @@ | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> | | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> | -| org.apache.spark.sql.catalyst.expressions.MapSort | map_sort | SELECT map_sort(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> | | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> | | org.apache.spark.sql.catalyst.expressions.MaskExpressionBuilder | mask | SELECT mask('abcd-EFGH-8765-4321') | struct | From 8d9ac51d95669efe6a4253e0071fb2ec665ae2d6 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 13:37:35 +0100 Subject: [PATCH 24/27] aditional erasions --- R/pkg/R/generics.R | 4 - .../spark/sql/DataFrameFunctionsSuite.scala | 82 +------------------ 2 files changed, 1 insertion(+), 85 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 2004530da88cb..26e81733055a6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1216,10 +1216,6 @@ setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @name NULL setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") }) -#' @rdname column_collection_functions -#' @name NULL -setGeneric("map_sort", function(x, asc = TRUE) { standardGeneric("map_sort") }) - #' @rdname column_aggregate_functions #' @name NULL setGeneric("max_by", function(x, y) { standardGeneric("max_by") }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e5953e59a51b1..e42f397cbfc29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -25,7 +25,7 @@ import java.sql.{Date, Timestamp} import scala.util.Random import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException} -import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.Cast._ @@ -780,86 +780,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } - test("map_sort function") { - val df1 = Seq( - Map[Int, Int](2 -> 2, 1 -> 1, 3 -> 3) - ).toDF("a") - - checkAnswer( - df1.selectExpr("map_sort(a)"), - Seq( - Row(Map(1 -> 1, 2 -> 2, 3 -> 3)) - ) - ) - checkAnswer( - df1.selectExpr("map_sort(a, true)"), - Seq( - Row(Map(1 -> 1, 2 -> 2, 3 -> 3)) - ) - ) - checkAnswer( - df1.select(map_sort($"a", asc = false)), - Seq( - Row(Map(3 -> 3, 2 -> 2, 1 -> 1)) - ) - ) - - val df2 = Seq(Map.empty[Int, Int]).toDF("a") - - checkAnswer( - df2.selectExpr("map_sort(a, true)"), - Seq(Row(Map())) - ) - - checkError( - exception = intercept[AnalysisException] { - df2.orderBy("a") - }, - errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", - parameters = Map( - "functionName" -> "`sortorder`", - "dataType" -> "\"MAP\"", - "sqlExpr" -> "\"a ASC NULLS FIRST\"") - ) - - checkError( - exception = intercept[SparkRuntimeException] { - sql("SELECT map_sort(map(null, 1))").collect() - }, - errorClass = "NULL_MAP_KEY" - ) - - checkError( - exception = intercept[ExtendedAnalysisException] { - sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", - "paramIndex" -> "second", - "inputSql" -> "\"asc\"", - "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"" - ), - queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")")) - ) - - checkError( - exception = intercept[ExtendedAnalysisException] { - sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() - }, - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", - "paramIndex" -> "second", - "inputSql" -> "\"asc\"", - "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"" - ), - queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")")) - ) - } - test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From 2951bcc189ceee08526b3d119586aaa72028bb00 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 14:14:52 +0100 Subject: [PATCH 25/27] removed ExpressionDescription --- .../expressions/collectionOperations.scala | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3ed711d477621..98ba1cad68309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -888,27 +888,6 @@ case class MapFromEntries(child: Expression) copy(child = newChild) } -@ExpressionDescription( - usage = """ - _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order - according to the natural ordering of the map keys. The algorithm used for sorting is - an adaptive, stable and iterative algorithm. If the input map is empty, function - returns an empty map. - """, - arguments = - """ - Arguments: - * map - The map that will be sorted. - * ascendingOrder - A boolean value describing the order in which the map will be sorted. - This can be either be ascending (true) or descending (false). - """, - examples = """ - Examples: - > SELECT _FUNC_(map(3, 'c', 1, 'a', 2, 'b'), true); - {1:"a",2:"b",3:"c"} - """, - group = "map_funcs", - since = "4.0.0") case class MapSort(base: Expression, ascendingOrder: Expression) extends BinaryExpression with NullIntolerant with QueryErrorsBase { From 0fc3c6a63b8257356fa4d782a62e158c5b6de914 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 14:43:15 +0100 Subject: [PATCH 26/27] Moved ordering outside of comapre function --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 98ba1cad68309..7de5eb755ebec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -980,6 +980,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") + val sortOrder = ctx.freshName("sortOrder") val boxedKeyType = CodeGenerator.boxedType(keyType) val boxedValueType = CodeGenerator.boxedType(valueType) @@ -1011,13 +1012,13 @@ case class MapSort(base: Expression, ascendingOrder: Expression) | ${CodeGenerator.getValue(keys, keyType, i)}, | ${CodeGenerator.getValue(values, valueType, i)}); |} - | + |final int $sortOrder = $order ? 1 : -1; |java.util.Arrays.sort($sortArray, new java.util.Comparator() { | @Override public int compare(Object $o1entry, Object $o2entry) { | Object $o1 = (($simpleEntryType) $o1entry).getKey(); | Object $o2 = (($simpleEntryType) $o2entry).getKey(); | $comp; - | return $order ? $c : -$c; + | return $sortOrder * $c; | } |}); | From 0c7d21a36e4e2eaee5ea4db67a59d69901c4d31e Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 22:56:11 +0100 Subject: [PATCH 27/27] Removed oredering type --- .../expressions/collectionOperations.scala | 47 +++++-------------- .../CollectionExpressionsSuite.scala | 29 +++--------- 2 files changed, 19 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7de5eb755ebec..27225b4ac74a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -888,33 +888,19 @@ case class MapFromEntries(child: Expression) copy(child = newChild) } -case class MapSort(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with NullIntolerant with QueryErrorsBase { - - def this(e: Expression) = this(e, Literal(true)) +case class MapSort(base: Expression) + extends UnaryExpression with NullIntolerant with QueryErrorsBase { val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType - override def left: Expression = base - override def right: Expression = ascendingOrder + override def child: Expression = base + override def dataType: DataType = base.dataType override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case m: MapType if RowOrdering.isOrderable(m.keyType) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(1), - "requiredType" -> toSQLType(BooleanType), - "inputSql" -> toSQLExpr(ascendingOrder), - "inputType" -> toSQLType(ascendingOrder.dataType)) - ) - } + TypeCheckResult.TypeCheckSuccess case _: MapType => DataTypeMismatch( errorSubClass = "INVALID_ORDERING_TYPE", @@ -934,7 +920,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) ) } - override def nullSafeEval(array: Any, ascending: Any): Any = { + override def nullSafeEval(array: Any): Any = { // put keys and their respective values inside a tuple and sort them // according to the key ordering. Extract the new sorted k/v pairs to form a sorted map @@ -943,11 +929,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val keys = mapData.keyArray() val values = mapData.valueArray() - val ordering = if (ascending.asInstanceOf[Boolean]) { - PhysicalDataType.ordering(keyType) - } else { - PhysicalDataType.ordering(keyType).reverse - } + val ordering = PhysicalDataType.ordering(keyType) val sortedMap = Array .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], @@ -959,11 +941,11 @@ case class MapSort(base: Expression, ascendingOrder: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + nullSafeCodeGen(ctx, ev, b => sortCodegen(ctx, ev, b)) } private def sortCodegen(ctx: CodegenContext, ev: ExprCode, - base: String, order: String): String = { + base: String): String = { val arrayBasedMapData = classOf[ArrayBasedMapData].getName val genericArrayData = classOf[GenericArrayData].getName @@ -980,7 +962,6 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") - val sortOrder = ctx.freshName("sortOrder") val boxedKeyType = CodeGenerator.boxedType(keyType) val boxedValueType = CodeGenerator.boxedType(valueType) @@ -1012,13 +993,13 @@ case class MapSort(base: Expression, ascendingOrder: Expression) | ${CodeGenerator.getValue(keys, keyType, i)}, | ${CodeGenerator.getValue(values, valueType, i)}); |} - |final int $sortOrder = $order ? 1 : -1; + | |java.util.Arrays.sort($sortArray, new java.util.Comparator() { | @Override public int compare(Object $o1entry, Object $o2entry) { | Object $o1 = (($simpleEntryType) $o1entry).getKey(); | Object $o2 = (($simpleEntryType) $o2entry).getKey(); | $comp; - | return $sortOrder * $c; + | return $c; | } |}); | @@ -1035,10 +1016,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) |""".stripMargin } - override def prettyName: String = "map_sort" - - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression) - : MapSort = copy(base = newLeft, ascendingOrder = newRight) + override protected def withNewChildInternal(newChild: Expression) + : MapSort = copy(base = newChild) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3063b83d4dca1..d14118eb3f1d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -434,31 +434,14 @@ class CollectionExpressionsSuite Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3), MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)) - checkEvaluation(new MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) - checkEvaluation(MapSort(intKey, Literal.create(false, BooleanType)), - Map(3 -> 3, 2 -> 2, 1 -> 1)) - - checkEvaluation(new MapSort(boolKey), Map(false -> 1, true -> 2)) - checkEvaluation(MapSort(boolKey, Literal.create(false, BooleanType)), - Map(true -> 2, false -> 1)) - - checkEvaluation(new MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) - checkEvaluation(MapSort(stringKey, Literal.create(false, BooleanType)), - Map("3" -> 3, "2" -> 2, "1" -> 1)) - - checkEvaluation(new MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) - checkEvaluation(MapSort(arrayKey, Literal.create(false, BooleanType)), - Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1)) - - checkEvaluation(new MapSort(nestedArrayKey), + checkEvaluation(MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) + checkEvaluation(MapSort(boolKey), Map(false -> 1, true -> 2)) + checkEvaluation(MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) + checkEvaluation(MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) + checkEvaluation(MapSort(nestedArrayKey), Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)) - checkEvaluation(MapSort(nestedArrayKey, Literal.create(false, BooleanType)), - Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1)) - - checkEvaluation(new MapSort(structKey), + checkEvaluation(MapSort(structKey), Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3)) - checkEvaluation(MapSort(structKey, Literal.create(false, BooleanType)), - Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1)) } test("Sort Array") {