diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index e0bc5f06ee..0c65495cff 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -74,6 +74,7 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on - **ArrayUnion**: Sorts input arrays before performing the union, while Spark preserves the order of the first array and appends unique elements from the second. [#3644](https://github.com/apache/datafusion-comet/issues/3644) +- **SortArray**: Nested arrays with `Struct` or `Null` child values are not supported natively and will fall back to Spark. ### Date/Time Expressions diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 5474894108..10da3b4dc7 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -105,7 +105,7 @@ - [ ] sequence - [ ] shuffle - [ ] slice -- [ ] sort_array +- [x] sort_array ### bitwise_funcs diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..425e74add9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayMin] -> CometArrayMin, classOf[ArrayRemove] -> CometArrayRemove, classOf[ArrayRepeat] -> CometArrayRepeat, + classOf[SortArray] -> CometSortArray, classOf[ArraysOverlap] -> CometArraysOverlap, classOf[ArrayUnion] -> CometArrayUnion, classOf[CreateArray] -> CometCreateArray, @@ -778,30 +779,23 @@ object QueryPlanSerde extends Logging with CometExprShim { * TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]] */ // scalastyle:on - def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = { - def 
canRank(dt: DataType): Boolean = { - dt match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | - _: DoubleType | _: DecimalType => - true - case _: DateType | _: TimestampType | _: TimestampNTZType => - true - case _: BooleanType | _: BinaryType | _: StringType => true - case _ => false - } + def supportedScalarSortElementType(dt: DataType): Boolean = { + dt match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: DecimalType | _: DateType | _: TimestampType | _: TimestampNTZType | + _: BooleanType | _: BinaryType | _: StringType => + true + case _ => + false } + } + def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = { if (sortOrder.length == 1) { val canSort = sortOrder.head.dataType match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | - _: DoubleType | _: DecimalType => - true - case _: DateType | _: TimestampType | _: TimestampNTZType => - true - case _: BooleanType | _: BinaryType | _: StringType => true - case ArrayType(elementType, _) => canRank(elementType) - case MapType(_, valueType, _) => canRank(valueType) - case _ => false + case ArrayType(elementType, _) => supportedScalarSortElementType(elementType) + case MapType(_, valueType, _) => supportedScalarSortElementType(valueType) + case _ => supportedScalarSortElementType(sortOrder.head.dataType) } if (!canSort) { withInfo(op, s"Sort on single column of type ${sortOrder.head.dataType} is not supported") diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index c82018fe6d..1847eb6021 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,11 +21,12 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, 
ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.QueryPlanSerde._ import org.apache.comet.shims.CometExprShim @@ -200,6 +201,80 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] { } } +object CometSortArray extends CometExpressionSerde[SortArray] { + private def containsFloatingPoint(dt: DataType): Boolean = { + dt match { + case FloatType | DoubleType => true + case ArrayType(elementType, _) => containsFloatingPoint(elementType) + case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType)) + case MapType(keyType, valueType, _) => + containsFloatingPoint(keyType) || containsFloatingPoint(valueType) + case _ => false + } + } + + private def supportedSortArrayElementType( + dt: DataType, + nestedInArray: Boolean = false): Boolean = { + dt match { + // DataFusion's array_sort compares nested arrays through Arrow's rank kernel. + // That kernel does not support Struct or Null child values, + // so array<array<struct<...>>> and array<array<null>> would fail at runtime.
+ case _: NullType if !nestedInArray => + true + case ArrayType(elementType, _) => + supportedSortArrayElementType(elementType, nestedInArray = true) + case StructType(fields) if !nestedInArray => + fields.forall(f => supportedSortArrayElementType(f.dataType)) + case _ => + supportedScalarSortElementType(dt) + } + } + + override def getSupportLevel(expr: SortArray): SupportLevel = { + val elementType = expr.base.dataType.asInstanceOf[ArrayType].elementType + + if (!supportedSortArrayElementType(elementType)) { + Unsupported(Some(s"Sort on array element type $elementType is not supported")) + } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && + containsFloatingPoint(elementType)) { + Incompatible( + Some( + "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " + + s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " + + s"${CometConf.COMPAT_GUIDE}")) + } else { + Compatible() + } + } + + override def convert( + expr: SortArray, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val arrayExprProto = exprToProtoInternal(expr.base, inputs, binding) + val (sortDirectionExprProto, nullOrderingExprProto) = expr.ascendingOrder match { + case Literal(value: Boolean, BooleanType) => + val direction = if (value) "ASC" else "DESC" + val nullOrdering = if (value) "NULLS FIRST" else "NULLS LAST" + ( + exprToProtoInternal(Literal(direction), inputs, binding), + exprToProtoInternal(Literal(nullOrdering), inputs, binding)) + case other => + withInfo(expr, s"ascendingOrder must be a boolean literal: $other") + (None, None) + } + + val sortArrayScalarExpr = + scalarFunctionExprToProto( + "array_sort", + arrayExprProto, + sortDirectionExprProto, + nullOrderingExprProto) + optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _*) + } +} + object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] { override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None) 
diff --git a/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql new file mode 100644 index 0000000000..1ced53394d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql @@ -0,0 +1,401 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_sort_array_int(arr array<int>) USING parquet + +statement +INSERT INTO test_sort_array_int VALUES + (array(3, 1, 4, 1, 5)), + (array(3, NULL, 1, NULL, 2)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_int + +query +SELECT sort_array(arr, true) FROM test_sort_array_int + +query +SELECT sort_array(arr, false) FROM test_sort_array_int + +statement +CREATE TABLE test_sort_array_string(arr array<string>) USING parquet + +statement +INSERT INTO test_sort_array_string VALUES + (array('d', 'c', 'b', 'a')), + (array('b', NULL, 'a')), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_string + +query +SELECT sort_array(arr, true) FROM test_sort_array_string + +query +SELECT sort_array(arr, false) FROM test_sort_array_string + +statement +CREATE TABLE test_sort_array_double(arr array<double>) USING parquet + +statement +INSERT INTO test_sort_array_double VALUES + (array( + CAST('Infinity' AS DOUBLE), + CAST('-Infinity' AS DOUBLE), + CAST('NaN' AS DOUBLE), + 3.0, + 1.0, + NULL, + -0.0, + 0.0)), + (array( + CAST('NaN' AS DOUBLE), + CAST('NaN' AS DOUBLE), + CAST('Infinity' AS DOUBLE), + CAST('-Infinity' AS DOUBLE), + -5.0, + 2.0)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_double + +query +SELECT sort_array(arr, true) FROM test_sort_array_double + +query +SELECT sort_array(arr, false) FROM test_sort_array_double + +statement +CREATE TABLE test_sort_array_float(arr array<float>) USING parquet + +statement +INSERT INTO test_sort_array_float VALUES + (array( + CAST('Infinity' AS FLOAT), + CAST('-Infinity' AS FLOAT), + CAST('NaN' AS FLOAT), + CAST(3.0 AS FLOAT), + CAST(1.0 AS FLOAT), + CAST(NULL AS FLOAT), + CAST(-0.0 AS FLOAT), + CAST(0.0 AS FLOAT))), + (array( + CAST('NaN' AS FLOAT), + CAST('Infinity' AS FLOAT), + CAST('-Infinity' AS FLOAT), + CAST(-5.0 AS FLOAT), + CAST(2.0 AS FLOAT))), + (array()), + (NULL) + +query +SELECT 
sort_array(arr) FROM test_sort_array_float + +query +SELECT sort_array(arr, true) FROM test_sort_array_float + +query +SELECT sort_array(arr, false) FROM test_sort_array_float + +statement +CREATE TABLE test_sort_array_decimal(arr array<decimal(10, 3)>) USING parquet + +statement +INSERT INTO test_sort_array_decimal VALUES + (CAST(array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0))) AS array<decimal(10, 3)>)), + (CAST(array( + CAST(1 AS DECIMAL(10, 0)), + CAST(1.0 AS DECIMAL(10, 1)), + CAST(1.00 AS DECIMAL(10, 2)), + CAST(1.000 AS DECIMAL(10, 3))) AS array<decimal(10, 3)>)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_decimal + +query +SELECT sort_array(arr, true) FROM test_sort_array_decimal + +query +SELECT sort_array(arr, false) FROM test_sort_array_decimal + +statement +CREATE TABLE test_sort_array_boolean(arr array<boolean>) USING parquet + +statement +INSERT INTO test_sort_array_boolean VALUES + (array(true, false, true, false)), + (array(true, false, true, NULL, false)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_boolean + +query +SELECT sort_array(arr, true) FROM test_sort_array_boolean + +query +SELECT sort_array(arr, false) FROM test_sort_array_boolean + +statement +CREATE TABLE test_sort_array_date(arr array<date>) USING parquet + +statement +INSERT INTO test_sort_array_date VALUES + (array(DATE '2026-01-03', DATE '2026-01-01', DATE '2026-01-02')), + (array(DATE '2026-01-02', NULL, DATE '2026-01-01')), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_date + +query +SELECT sort_array(arr, true) FROM test_sort_array_date + +query +SELECT sort_array(arr, false) FROM test_sort_array_date + +statement +CREATE TABLE test_sort_array_timestamp(arr array<timestamp>) USING parquet + +statement +INSERT INTO test_sort_array_timestamp VALUES + (array( + TIMESTAMP '2026-01-03 01:00:00', + TIMESTAMP '2026-01-01 02:00:00', + TIMESTAMP '2026-01-02 03:00:00')), + (array( + TIMESTAMP '2026-01-02 00:00:00', + NULL, + TIMESTAMP '2026-01-01 
00:00:00')), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_timestamp + +query +SELECT sort_array(arr, true) FROM test_sort_array_timestamp + +query +SELECT sort_array(arr, false) FROM test_sort_array_timestamp + +statement +CREATE TABLE test_sort_array_binary(arr array<binary>) USING parquet + +statement +INSERT INTO test_sort_array_binary VALUES + (array(unhex('FF'), unhex('00'), unhex('0A'))), + (array(unhex('0B'), NULL, unhex('01'))), + (array()), + (NULL) + +query +SELECT + hex(element_at(sorted_arr, 1)), + hex(element_at(sorted_arr, 2)), + hex(element_at(sorted_arr, 3)) +FROM ( + SELECT sort_array(arr) AS sorted_arr + FROM test_sort_array_binary +) + +query +SELECT + hex(element_at(sorted_arr, 1)), + hex(element_at(sorted_arr, 2)), + hex(element_at(sorted_arr, 3)) +FROM ( + SELECT sort_array(arr, true) AS sorted_arr + FROM test_sort_array_binary +) + +query +SELECT + hex(element_at(sorted_arr, 1)), + hex(element_at(sorted_arr, 2)), + hex(element_at(sorted_arr, 3)) +FROM ( + SELECT sort_array(arr, false) AS sorted_arr + FROM test_sort_array_binary +) + +statement +CREATE TABLE test_sort_array_struct(arr array<struct<a: int, b: string>>) USING parquet + +statement +INSERT INTO test_sort_array_struct VALUES + (array( + named_struct('a', 2, 'b', 'b'), + named_struct('a', 1, 'b', 'c'), + named_struct('a', 1, 'b', 'a'))), + (array( + named_struct('a', 2, 'b', NULL), + named_struct('a', 1, 'b', 'z'), + named_struct('a', 1, 'b', NULL))), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_struct + +query +SELECT sort_array(arr, false) FROM test_sort_array_struct + +statement +CREATE TABLE test_sort_array_nested(arr array<array<int>>) USING parquet + +statement +INSERT INTO test_sort_array_nested VALUES + (array(array(2, 3), array(1), array(2, 1))), + (array(array(1, NULL), array(1), NULL)), + (array()), + (NULL) + +query +SELECT sort_array(arr) FROM test_sort_array_nested + +query +SELECT sort_array(arr, false) FROM test_sort_array_nested + +statement +CREATE 
TABLE test_sort_array_nested_struct(arr array<array<struct<a: int>>>) USING parquet + +statement +INSERT INTO test_sort_array_nested_struct VALUES + (array( + array(named_struct('a', 2)), + array(named_struct('a', 1)))), + (array()), + (NULL) + +query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +SELECT sort_array(arr) FROM test_sort_array_nested_struct + +query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +SELECT sort_array(arr, false) FROM test_sort_array_nested_struct + +-- literal arguments +query +SELECT + sort_array(array(3, 1, 4, 1, 5)), + sort_array(array(3, 1, 4, 1, 5), true), + sort_array(array(3, NULL, 1, NULL, 2)), + sort_array(array(3, NULL, 1, NULL, 2), false), + sort_array( + array( + CAST('Infinity' AS DOUBLE), + CAST('-Infinity' AS DOUBLE), + CAST('NaN' AS DOUBLE), + 1.0, + NULL, + -0.0, + 0.0)), + sort_array( + array( + CAST('Infinity' AS DOUBLE), + CAST('-Infinity' AS DOUBLE), + CAST('NaN' AS DOUBLE), + 1.0, + NULL, + -0.0, + 0.0), + false), + sort_array( + array( + CAST('Infinity' AS FLOAT), + CAST('-Infinity' AS FLOAT), + CAST('NaN' AS FLOAT), + CAST(1.0 AS FLOAT), + CAST(NULL AS FLOAT), + CAST(-0.0 AS FLOAT), + CAST(0.0 AS FLOAT))), + sort_array( + array( + CAST('Infinity' AS FLOAT), + CAST('-Infinity' AS FLOAT), + CAST('NaN' AS FLOAT), + CAST(1.0 AS FLOAT), + CAST(NULL AS FLOAT), + CAST(-0.0 AS FLOAT), + CAST(0.0 AS FLOAT)), + false), + sort_array( + CAST(array( + CAST(100 AS DECIMAL(10, 0)), + CAST(10 AS DECIMAL(10, 0)), + CAST(1 AS DECIMAL(10, 0)), + CAST(1.00 AS DECIMAL(10, 2))) AS array<decimal(10, 2)>)), + sort_array( + CAST(array( + CAST(100 AS DECIMAL(10, 0)), + CAST(10 AS DECIMAL(10, 0)), + CAST(1 AS DECIMAL(10, 0)), + CAST(1.00 AS DECIMAL(10, 2))) AS array<decimal(10, 2)>), + false), + sort_array(array(true, false, true, false)), + sort_array(array(true, false, true, NULL, false)), + sort_array(array(true, false, true, NULL, false), false), + sort_array(array(DATE '2026-01-03', DATE 
'2026-01-01', DATE '2026-01-02')), + sort_array(array(DATE '2026-01-02', NULL, DATE '2026-01-01'), false), + sort_array( + array( + TIMESTAMP '2026-01-03 01:00:00', + TIMESTAMP '2026-01-01 02:00:00', + TIMESTAMP '2026-01-02 03:00:00')), + sort_array( + array( + TIMESTAMP '2026-01-02 00:00:00', + NULL, + TIMESTAMP '2026-01-01 00:00:00'), + false), + hex(element_at(sort_array(array(unhex('FF'), unhex('00'), unhex('0A'))), 1)), + hex(element_at(sort_array(array(unhex('FF'), unhex('00'), unhex('0A'))), 2)), + hex(element_at(sort_array(array(unhex('FF'), unhex('00'), unhex('0A'))), 3)), + hex(element_at(sort_array(array(unhex('0B'), NULL, unhex('01')), false), 1)), + hex(element_at(sort_array(array(unhex('0B'), NULL, unhex('01')), false), 2)), + hex(element_at(sort_array(array(unhex('0B'), NULL, unhex('01')), false), 3)), + sort_array( + array( + named_struct('a', 2, 'b', 'b'), + named_struct('a', 1, 'b', 'c'), + named_struct('a', 1, 'b', 'a'))), + sort_array(array(array(2, 3), array(1), array(2, 1))), + sort_array(array(array(1, NULL), array(1), NULL)), + sort_array(array(NULL, NULL)), + sort_array(cast(NULL as array<int>)) + +query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType) +SELECT sort_array( + array( + array(named_struct('a', 2)), + array(named_struct('a', 1)))) + +query expect_error(BOOLEAN) +SELECT sort_array(array(3, 1, 4, 1, 5), 1)