From 78f960bfb5465b1ca85a90deb18447f223156ae1 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:34:49 -0800
Subject: [PATCH 01/68] Adds an array_prepend expression to catalyst

---
 .../reference/pyspark.sql/functions.rst       |   1 +
 python/pyspark/sql/functions.py               |  27 ++++-
 .../catalyst/analysis/FunctionRegistry.scala  |  12 ++
 .../expressions/collectionOperations.scala    | 113 ++++++++++++++++++
 .../CollectionExpressionsSuite.scala          |  41 +++++++
 .../org/apache/spark/sql/functions.scala      |  10 ++
 .../spark/sql/DataFrameFunctionsSuite.scala   |  52 ++++++++
 7 files changed, 255 insertions(+), 1 deletion(-)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index ddc8eab90f77a..13a64b7215423 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -158,6 +158,7 @@ Collection Functions
     array_append
     array_sort
     array_remove
+    array_prepend
     array_distinct
     array_intersect
     array_union
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3426f2bdaf6c1..edd9f24b92ecc 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7618,6 +7618,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)


+def array_prepend(col: "ColumnOrName", element: Any) -> Column:
+    """
+    Collection function: Returns an array containing value as well as all elements from array.
+    The new element is positioned at the beginning of the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    element :
+        element to be prepended to the array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array with the given value prepended.
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -7649,7 +7675,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 99bab1003767d..a1750d6d45d8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -694,6 +694,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), @@ -967,6 +968,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -976,6 +978,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ca3982f54c8bb..ff0bce1687871 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,119 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 4); + [4, 1, 2, 3] + """, + group = "array_funcs", + since = "3.4.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def nullSafeEval(arr: Any, value: Any): Any = { + val numberOfElements = arr.asInstanceOf[ArrayData].numElements() + if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val newArray = new Array[Any](numberOfElements + 1) + newArray(0) = value + var pos = 1 + arr + .asInstanceOf[ArrayData] + .foreach( + right.dataType, + (i, v) => { + newArray(pos) = v + pos += 1 + }) + new GenericArrayData(newArray) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + (arr, value) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + }) + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) | (NullType, _) => + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) + case (l, _) if !ArrayType.acceptsType(l) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> toSQLType(left.dataType))) + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, prettyName) + case _ => + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType))) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d83739df38d14..a4c3ffeeb0d71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1840,6 +1840,47 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val b0 = Literal.create( + Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + checkEvaluation(ArrayPrepend(b0, nullBinary), null) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3d5547ead8310..be0c99aeec716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. 
+ * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 231c9562511a2..d8b54dbc76eef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq(Row(Seq(1.23, 1.0, 2.0)))) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "1", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), From f2d4f68ab2db1ba7c1c896f4474d63caab585297 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 02/68] Fix null handling --- .../expressions/collectionOperations.scala | 122 +++++++++++------- .../CollectionExpressionsSuite.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 15 +++ 3 files changed, 101 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ff0bce1687871..000e61bb6a21b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,9 +1413,19 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with QueryErrorsBase { + override def nullable: Boolean = left.nullable + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } override def nullSafeEval(arr: Any, value: Any): Any = { val numberOfElements = arr.asInstanceOf[ArrayData].numElements() if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -1435,36 +1445,57 @@ case class ArrayPrepend(left: Expression, right: Expression) new GenericArrayData(newArray) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen( - ctx, - ev, - (arr, value) => { - val newArraySize = ctx.freshName("newArraySize") - val newArray = ctx.freshName("newArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val allocation = CodeGenerator.createArrayData( - newArray, - right.dataType, - newArraySize, - s" $prettyName failed.") - val assignment = - CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) - val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { s""" - |int $pos = 0; - |int $newArraySize = $arr.numElements() + 1; - |$allocation - |$newElemAssignment - |$pos = $pos + 1; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $assignment - | $pos = $pos + 1; - |} - |${ev.value} = $newArray; + |${ev.isNull} = false; + |${resultCode} |""".stripMargin - }) + } + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = + code""" + ${leftGen.code} + ${rightGen.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) + } } override def prettyName: String = "array_prepend" @@ -1472,31 +1503,30 @@ case class ArrayPrepend(left: Expression, right: Expression) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = 
newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (_, NullType) | (NullType, _) => - DataTypeMismatch( - errorSubClass = "NULL_TYPE", - messageParameters = Map("functionName" -> toSQLId(prettyName))) - case (l, _) if !ArrayType.acceptsType(l) => + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "requiredType" -> toSQLType(ArrayType), "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType))) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, prettyName) - case _ => - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(ArrayType), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType))) + "inputType" -> toSQLType(left.dataType) + ) + ) } } override def inputTypes: Seq[AbstractDataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a4c3ffeeb0d71..b443a1c7f5aa8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1849,21 +1847,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) - checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) checkEvaluation(ArrayPrepend(a3, Literal("a")), null) checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) val b0 = Literal.create( - Seq[Array[Byte]]( - Array[Byte](5, 6), - Array[Byte](1, 2), - Array[Byte](1, 2), - Array[Byte](5, 6)), + data, ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - checkEvaluation(ArrayPrepend(b0, nullBinary), null) + // Calling ArrayPrepend with a null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + 
checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) checkEvaluation( ArrayPrepend(b1, dataToPrepend1), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index be0c99aeec716..23f78e532ec0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 85a8b4ce9b9bed366df5de073eb8d8f2c0c37b99 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 03/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 23f78e532ec0e..79e42db22776a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From d7b601cedb9c8eb945e196d0e168d68f64bf0af0 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 04/68] Fix --- .../catalyst/analysis/FunctionRegistry.scala | 11 ----------- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a1750d6d45d8e..d9765a20a80bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -968,7 +968,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -978,16 +977,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 23f78e532ec0e..79e42db22776a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From d827d21e33e2c9dc99fa653c8550ad6e28e96330 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 05/68] Lint --- .../expressions/CollectionExpressionsSuite.scala | 1 + .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b443a1c7f5aa8..fc769b3f1773d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d8b54dbc76eef..a4641ee646d48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2664,13 +2664,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null, null))) checkAnswer( df.select(array_prepend($"a", $"d")), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( df.selectExpr("array_prepend(a, d)"), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), - Seq(Row(Seq(1.23, 1.0, 2.0)))) + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) checkAnswer( df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), Seq( From 4e7aa1e4504b014f7b0a380a97faa4a61b6f34ce Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 06/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fc769b3f1773d..f94d216beed69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From ff8c19e4cf40641fe072d00118b393065f8e5416 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 07/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 
+-
 .../sql-functions/sql-expression-schema.md    |  3 +-
 .../test/resources/sql-tests/inputs/array.sql | 11 +++
 .../sql-tests/results/ansi/array.sql.out      | 72 +++++++++++++++++++
 4 files changed, 89 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 572465ff83467..548b0266d4efe 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7618,10 +7618,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)


+@try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
-    Collection function: Returns an array containing value as well as all elements from array.
-    The new element is positioned at the beginning of the array.
+    Collection function: Returns an array containing element as
+    well as all elements from array. The new element is positioned
+    at the beginning of the array.

     .. versionadded:: 3.4.0

diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 03ec4bce54b44..cf355e11fc4ea 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -13,6 +13,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct<aggregate(array(1, 2, 3), 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct<reduce(array(1, 2, 3), 0, lambdafunction((namedlambdavariable() + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable())):int> |
 | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct<array_append(array(b, d, c, a), d):array<string>> |
+| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct<array_prepend(array(b, d, c, a), d):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct<array_compact(array(1, 2, 3, null)):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct<array_contains(array(1, 2, 3), 2):boolean> |
 | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct<array_distinct(array(1, 2, 3, null, 3)):array<int>> |
@@ -420,4 +421,4 @@
 | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
-| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
\ No newline at end of file
+| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index 3d107cb6dfc07..d3c36b79d1f3a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY<String>), CAST(null as String));
 select array_append(array(), 1);
 select array_append(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
 select array_append(array(CAST(NULL AS String)), CAST(NULL AS String));
+
+-- function array_prepend
+select array_prepend(array(1, 2, 3), 4);
+select array_prepend(array('a', 'b', 'c'), 'd');
+select array_prepend(array(1, 2, 3, NULL), NULL);
+select array_prepend(array('a', 'b', 'c', NULL), NULL);
+select array_prepend(CAST(null AS ARRAY<String>), 'a');
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String));
+select array_prepend(array(), 1);
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String));
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index 0d8ef39ed60c6..d228c605705d6 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String))
 struct<array_append(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
 -- !query output
 [null,null]
+
+
+-- !query
+select array_prepend(array(1, 2, 3), 4)
+-- !query schema
+struct<array_prepend(array(1, 2, 3), 4):array<int>>
+-- !query output
+[4,1,2,3]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c'), 'd')
+-- !query schema
+struct<array_prepend(array(a, b, c), d):array<string>>
+-- !query output
+["d","a","b","c"]
+
+
+-- !query
+select array_prepend(array(1, 2, 3, NULL), NULL)
+-- !query schema
+struct<array_prepend(array(1, 2, 3, NULL), NULL):array<int>>
+-- !query output
+[null,1,2,3,null]
+
+
+-- !query
+select array_prepend(array('a', 'b', 'c', NULL), NULL)
+-- !query schema
+struct<array_prepend(array(a, b, c, NULL), NULL):array<string>>
+-- !query output
+[null,"a","b","c",null]
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), 'a')
+-- !query schema
+struct<array_prepend(CAST(NULL AS ARRAY<STRING>), a):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(CAST(null AS ARRAY<String>), CAST(null as String))
+-- !query schema
+struct<array_prepend(CAST(NULL AS ARRAY<STRING>), CAST(NULL AS STRING)):array<string>>
+-- !query output
+NULL
+
+
+-- !query
+select array_prepend(array(), 1)
+-- !query schema
+struct<array_prepend(array(), 1):array<int>>
+-- !query output
+[1]
+
+
+-- !query
+select array_prepend(CAST(array() AS ARRAY<String>), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(CAST(array() AS ARRAY<STRING>), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null]
+
+
+-- !query
+select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String))
+-- !query schema
+struct<array_prepend(array(CAST(NULL AS STRING)), CAST(NULL AS STRING)):array<string>>
+-- !query output
+[null,null]
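Taken together, the new golden files fix the null contract of the function at the SQL level: a NULL array yields NULL, an untyped empty array adopts the type of the element being prepended, and a NULL element is kept as the new first entry. That contract can be modelled in a few lines of plain Scala (a behavioural sketch only, with Option standing in for a SQL NULL array; this is not the Catalyst implementation):

```scala
object ArrayPrependModel {
  // Option models a SQL NULL array; a null element is a legal entry.
  def arrayPrepend[T](arr: Option[Seq[T]], elem: T): Option[Seq[T]] =
    arr.map(elem +: _)

  def main(args: Array[String]): Unit = {
    println(arrayPrepend(Some(Seq(1, 2, 3)), 4))     // Some(List(4, 1, 2, 3))
    println(arrayPrepend(None, 4))                   // None, i.e. SQL NULL
    println(arrayPrepend(Some(Seq("a", "b")), null)) // Some(List(null, a, b))
  }
}
```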
From 4f3a9685303bcaf57dcdeb4ac017b25c357064e6 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Thu, 9 Feb 2023 20:44:09 -0800
Subject: [PATCH 08/68] Fix tests

---
 python/pyspark/sql/functions.py               |  9 +--
 .../expressions/collectionOperations.scala    |  3 +-
 .../sql-functions/sql-expression-schema.md    |  2 +-
 .../resources/sql-tests/results/array.sql.out | 72 +++++++++++++++++++
 4 files changed, 80 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 548b0266d4efe..c8a709d27c7c1 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7621,10 +7621,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
-    Collection function: Returns an array containing element as
-    well as all elements from array. The new element is positioned
+    Collection function: Returns an array containing element as
+    well as all elements from array. The new element is positioned
     at the beginning of the array.
-
+
     .. versionadded:: 3.4.0

     Parameters
@@ -7636,6 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:

     Returns
     -------
+    :class:`~pyspark.sql.Column`
         an array with the given value prepended.

     Examples
     --------
     >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
     >>> df.select(array_prepend(df.data, 1)).collect()
     [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
-    """
+    """
     return _invoke_function("array_prepend", _to_java_column(col), element)

 @try_remote_functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b122a629585b2..6e2beda4bccd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression)
 case class ArrayPrepend(left: Expression, right: Expression)
   extends BinaryExpression
     with ImplicitCastInputTypes
+    with ComplexTypeMergingExpression
     with QueryErrorsBase {

   override def nullable: Boolean = left.nullable
@@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression)
     (left.dataType, right.dataType) match {
       case (_, NullType) => Seq.empty
       case (ArrayType(e1, hasNull), e2) =>
-        TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match {
+        TypeCoercion.findTightestCommonType(e1, e2) match {
           case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
           case _ => Seq.empty
         }
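The switch to findTightestCommonType makes the reconciliation of the element type with the array's element type stricter: equal types unify, NULL adopts the other side, numerics widen along the standard precedence chain, and nothing else is accepted. That is the behaviour the suites pin down — `array_prepend(array(1, 2), 1.23D)` yields an array of doubles while `array_prepend(array(1, 2), '1')` raises DATATYPE_MISMATCH. A toy resolver sketches the visible rule (illustrative names only; the real logic is TypeCoercion.findTightestCommonType in catalyst, and this sketch does not capture every difference from the previous resolver):

```scala
sealed trait DT
case object IntT extends DT
case object DoubleT extends DT
case object StringT extends DT

object TightestDemo {
  // Toy stand-in for TypeCoercion.findTightestCommonType.
  def tightest(a: DT, b: DT): Option[DT] = (a, b) match {
    case (x, y) if x == y                  => Some(x)
    case (IntT, DoubleT) | (DoubleT, IntT) => Some(DoubleT) // numeric widening
    case _                                 => None          // no string promotion
  }

  def main(args: Array[String]): Unit = {
    // array_prepend(array(1, 2), 1.23D): element type widens to double
    assert(tightest(IntT, DoubleT).contains(DoubleT))
    // array_prepend(array(1, 2), '1'): no common type -> DATATYPE_MISMATCH
    assert(tightest(IntT, StringT).isEmpty)
  }
}
```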
@@ -7644,7 +7645,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b122a629585b2..6e2beda4bccd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes + with ComplexTypeMergingExpression with QueryErrorsBase { override def nullable: Boolean = left.nullable @@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index cf355e11fc4ea..6146b7fcb9c06 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -27,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | 
From ec9ea76e9cfa7bf78bc277f447fdd6c9cb95f3e4 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 10 Feb 2023 22:03:34 -0800
Subject: [PATCH 09/68] Fix types

---
 python/pyspark/sql/functions.py                                | 2 +-
 .../spark/sql/catalyst/expressions/collectionOperations.scala  | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index c8a709d27c7c1..f9230f5478eb5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7624,7 +7624,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     Collection function: Returns an array containing element as
     well as all elements from array. The new element is positioned
     at the beginning of the array.
-
+
     .. versionadded:: 3.4.0

     Parameters

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6e2beda4bccd8..73be6327bca15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1532,7 +1532,6 @@ case class ArrayPrepend(left: Expression, right: Expression)
   }
   override def inputTypes: Seq[AbstractDataType] = {
     (left.dataType, right.dataType) match {
-      case (_, NullType) => Seq.empty
       case (ArrayType(e1, hasNull), e2) =>
         TypeCoercion.findTightestCommonType(e1, e2) match {
           case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
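With the `case (_, NullType) => Seq.empty` arm gone, a NULL element no longer opts out of coercion: it goes through the same resolution as any other type, and since NullType unifies with anything, `array_prepend(array(1, 2, 3), NULL)` still types as array<int>, matching the golden results above. A small check against catalyst itself (internal API, shown only as a sketch that may drift between versions):

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.types.{IntegerType, NullType}

object NullCoercionCheck {
  def main(args: Array[String]): Unit = {
    // NullType unifies with the array's element type, so the expected
    // input types become (ArrayType(IntegerType), IntegerType) and the
    // analyzer inserts the cast for the NULL literal.
    val common = TypeCoercion.findTightestCommonType(IntegerType, NullType)
    assert(common.contains(IntegerType))
  }
}
```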
From c728581bfad8e36b29d6bf74fc40a8ea8a3f5c6c Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Sun, 26 Feb 2023 16:46:34 -0800
Subject: [PATCH 10/68] Fix tests

---
 python/pyspark/sql/functions.py                                | 1 -
 .../spark/sql/catalyst/expressions/collectionOperations.scala  | 4 ++--
 .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala   | 2 +-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f9230f5478eb5..294ec3669a98b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7636,7 +7636,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:

     Returns
     -------
-
     :class:`~pyspark.sql.Column`
         an array with the given value prepended.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 73be6327bca15..068d18f1727b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1405,8 +1405,8 @@
     "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 4); - [4, 1, 2, 3] + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] """, group = "array_funcs", since = "3.4.0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e342903962127..4fd350d8db265 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2692,7 +2692,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "sqlExpr" -> "\"array_prepend(_1, _2)\"", "inputSql" -> "\"_1\"", "inputType" -> "\"STRING\"", From 6eba188c53abe689df84af448f0626679ed73708 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 11/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 294ec3669a98b..915470b06ca3e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,6 +7618,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7644,9 +7645,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ @@ -7677,6 +7679,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From 63ff6cc605719bcbde4b9f30bd484b7c9e3ed575 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:30:18 -0800 Subject: [PATCH 12/68] Add test for null cases --- .../expressions/collectionOperations.scala | 28 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 7 +++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 068d18f1727b3..0a4680193e014 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1418,6 +1418,9 @@ case class ArrayPrepend(left: Expression, right: Expression) override def nullable: Boolean = left.nullable + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + override def eval(input: InternalRow): Any = { val value1 = left.eval(input) if (value1 == null) { @@ -1427,23 +1430,16 @@ case class ArrayPrepend(left: 
Expression, right: Expression) nullSafeEval(value1, value2) } } - override def nullSafeEval(arr: Any, value: Any): Any = { - val numberOfElements = arr.asInstanceOf[ArrayData].numElements() - if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) } - val newArray = new Array[Any](numberOfElements + 1) - newArray(0) = value - var pos = 1 - arr - .asInstanceOf[ArrayData] - .foreach( - right.dataType, - (i, v) => { - newArray(pos) = v - pos += 1 - }) - new GenericArrayData(newArray) + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) @@ -1505,7 +1501,7 @@ case class ArrayPrepend(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) - override def dataType: DataType = left.dataType + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4fd350d8db265..bc096f923fa2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2710,6 +2710,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "leftType" -> "\"ARRAY\"", "rightType" -> "\"STRING\""), queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) } test("array remove") { From 73a7dd78550b20321a0c6313c3b6e651848ae176 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:35:43 -0800 Subject: [PATCH 13/68] Fix type of array --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0a4680193e014..d27f3d3f78517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1451,13 +1451,13 @@ case class ArrayPrepend(left: Expression, right: Expression) val pos = ctx.freshName("pos") val allocation = CodeGenerator.createArrayData( newArray, - right.dataType, + elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, 
right.dataType, arr, pos, i, false)
+          CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false)
         val newElemAssignment =
-          CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull))
+          CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull))

From 6f97761ed7596c66f943737bee22cacd5200fd95 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:34:49 -0800
Subject: [PATCH 14/68] Adds an array_prepend expression to catalyst

---
 .../reference/pyspark.sql/functions.rst       |   1 +
 python/pyspark/sql/functions.py               |  27 ++++-
 .../catalyst/analysis/FunctionRegistry.scala  |  12 ++
 .../expressions/collectionOperations.scala    | 113 ++++++++++++++++++
 .../CollectionExpressionsSuite.scala          |  41 +++++++
 .../org/apache/spark/sql/functions.scala      |  10 ++
 .../spark/sql/DataFrameFunctionsSuite.scala   |  52 ++++++++
 7 files changed, 255 insertions(+), 1 deletion(-)

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 70fc04ef9cf23..cbc46e1fae18c 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -159,6 +159,7 @@ Collection Functions
     array_sort
     array_insert
     array_remove
+    array_prepend
     array_distinct
     array_intersect
     array_union
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8bee517de6af7..572465ff83467 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7618,6 +7618,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)


+def array_prepend(col: "ColumnOrName", element: Any) -> Column:
+    """
+    Collection function: Returns an array containing value as well as all elements from array.
+    The new element is positioned at the beginning of the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    element :
+        element to be prepended to the array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array with the given value prepended.
+ + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -7649,7 +7675,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d87cc0126cfa3..ce9e58722a2c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,6 +696,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), @@ -969,6 +970,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -978,6 +980,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 92a3127d438ac..226b8fcdddd66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,119 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 4); + [4, 1, 2, 3] + """, + group = "array_funcs", + since = "3.4.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def nullSafeEval(arr: Any, value: Any): Any = { + val numberOfElements = arr.asInstanceOf[ArrayData].numElements() + if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val newArray = new Array[Any](numberOfElements + 1) + newArray(0) = value + var pos = 1 + arr + .asInstanceOf[ArrayData] + .foreach( + right.dataType, + (i, v) => { + newArray(pos) = v + pos += 1 + }) + new GenericArrayData(newArray) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + (arr, value) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + }) + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) | (NullType, _) => + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) + case (l, _) if !ArrayType.acceptsType(l) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> toSQLType(left.dataType))) + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, prettyName) + case _ => + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType))) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. 
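Taken together, the expression above boils down to a small array-copy routine. The following is a minimal plain-Scala sketch of the interpreted-path semantics, not the Catalyst class itself; the object name is illustrative, and the size guard assumes ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH is Int.MaxValue - 15.

```scala
// Minimal model of the nullSafeEval path: allocate length + 1 slots, write the
// new element at index 0, then copy the existing elements shifted right by one.
object ArrayPrependSketch {
  // Assumed to mirror ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
  private val MaxRoundedArrayLength = Int.MaxValue - 15

  def prepend(arr: IndexedSeq[Any], value: Any): IndexedSeq[Any] = {
    if (arr.length + 1 > MaxRoundedArrayLength) {
      throw new IllegalStateException("array_prepend failed: result would exceed the array size limit")
    }
    val out = new Array[Any](arr.length + 1)
    out(0) = value // a null element is written as-is rather than filtered out
    var i = 0
    while (i < arr.length) {
      out(i + 1) = arr(i)
      i += 1
    }
    out.toIndexedSeq
  }
}
```

For example, ArrayPrependSketch.prepend(Vector(1, 2, 3), 4) yields Vector(4, 1, 2, 3), matching the ExpressionDescription example above.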
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 9b97430594d04..56472a553af2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1840,6 +1840,47 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val b0 = Literal.create( + Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + checkEvaluation(ArrayPrepend(b0, nullBinary), null) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb5c1ad5c4954..d2f1df8780b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. 
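The suite additions above drive everything through checkEvaluation, which exercises both the interpreted and the codegen path. For quick ad-hoc checks while reviewing, the expression can also be evaluated directly; a hedged sketch, assuming the Catalyst classes from this patch are on the classpath:

```scala
// Evaluate ArrayPrepend against literal inputs, outside of a full query plan.
import org.apache.spark.sql.catalyst.expressions.{ArrayPrepend, EmptyRow, Literal}
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val arr = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
// Interpreted evaluation; yields an ArrayData holding [0, 1, 2, 3, 4].
val prepended = ArrayPrepend(arr, Literal(0)).eval(EmptyRow)

// A null array short-circuits to a null result, as the a3 cases above assert.
val nullArr = Literal.create(null, ArrayType(IntegerType))
assert(ArrayPrepend(nullArr, Literal(0)).eval(EmptyRow) == null)
```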
+ * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6ed8299976c0a..f31278bae006a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq(Row(Seq(1.23, 1.0, 2.0)))) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "1", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), From fec08e9fd1ae7aa2ce08e6ae4a054fadfb1143ad Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 15/68] Fix null handling --- .../expressions/collectionOperations.scala | 122 +++++++++++------- .../CollectionExpressionsSuite.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 15 +++ 3 files changed, 101 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 226b8fcdddd66..b122a629585b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,9 +1413,19 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with QueryErrorsBase { + override def nullable: Boolean = left.nullable + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } override def nullSafeEval(arr: Any, value: Any): Any = { val numberOfElements = arr.asInstanceOf[ArrayData].numElements() if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -1435,36 +1445,57 @@ case class ArrayPrepend(left: Expression, right: Expression) new GenericArrayData(newArray) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen( - ctx, - ev, - (arr, value) => { - val newArraySize = ctx.freshName("newArraySize") - val newArray = ctx.freshName("newArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val allocation = CodeGenerator.createArrayData( - newArray, - right.dataType, - newArraySize, - s" $prettyName failed.") - val assignment = - CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) - val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { s""" - |int $pos = 0; - |int $newArraySize = $arr.numElements() + 1; - |$allocation - |$newElemAssignment - |$pos = $pos + 1; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $assignment - | $pos = $pos + 1; - |} - |${ev.value} = $newArray; + |${ev.isNull} = false; + |${resultCode} |""".stripMargin - }) + } + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = + code""" + ${leftGen.code} + ${rightGen.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) + } } override def prettyName: String = "array_prepend" @@ -1472,31 +1503,30 @@ case class ArrayPrepend(left: Expression, right: Expression) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = 
newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (_, NullType) | (NullType, _) => - DataTypeMismatch( - errorSubClass = "NULL_TYPE", - messageParameters = Map("functionName" -> toSQLId(prettyName))) - case (l, _) if !ArrayType.acceptsType(l) => + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "requiredType" -> toSQLType(ArrayType), "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType))) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, prettyName) - case _ => - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(ArrayType), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType))) + "inputType" -> toSQLType(left.dataType) + ) + ) } } override def inputTypes: Seq[AbstractDataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 56472a553af2f..dc8cc44a65356 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1849,21 +1847,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) - checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) checkEvaluation(ArrayPrepend(a3, Literal("a")), null) checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) val b0 = Literal.create( - Seq[Array[Byte]]( - Array[Byte](5, 6), - Array[Byte](1, 2), - Array[Byte](1, 2), - Array[Byte](5, 6)), + data, ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - checkEvaluation(ArrayPrepend(b0, nullBinary), null) + // Calling ArrayPrepend with a null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + 
checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) checkEvaluation( ArrayPrepend(b1, dataToPrepend1), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f1df8780b3f..1f66a5daa2d26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From a8da3455a17bf22cb1b6695c671117599521faa7 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 16/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1f66a5daa2d26..069e7b79fcb7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 3af8fd96c156fe8c25686cedf250b55a7a07ef90 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 17/68] Fix --- .../sql/catalyst/analysis/FunctionRegistry.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ce9e58722a2c1..472396bbef227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -970,7 +970,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -980,16 +979,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } From 2cd3e180e24239075b8469877924f2d8040e31e1 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 18/68] Lint --- .../expressions/CollectionExpressionsSuite.scala | 1 + .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index dc8cc44a65356..9ace0cbf854b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f31278bae006a..e342903962127 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2664,13 +2664,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null, null))) checkAnswer( df.select(array_prepend($"a", $"d")), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( df.selectExpr("array_prepend(a, d)"), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + 
Row(Seq(2)), + Row(null))) checkAnswer( OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), - Seq(Row(Seq(1.23, 1.0, 2.0)))) + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) checkAnswer( df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), Seq( From a307cb401887ac43fd78083ff035179552ab7e32 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 19/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 9ace0cbf854b8..667f717cce77c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From 7b2450011791c8037f320ae6c1341d7f52ecff44 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 20/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 +- .../sql-functions/sql-expression-schema.md | 3 +- .../test/resources/sql-tests/inputs/array.sql | 11 +++ .../sql-tests/results/ansi/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 572465ff83467..548b0266d4efe 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,10 +7618,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. 
versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 03ec4bce54b44..cf355e11fc4ea 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -420,4 +421,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 3d107cb6dfc07..d3c36b79d1f3a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY), CAST(null as String)); select array_append(array(), 1); select array_append(CAST(array() AS ARRAY), CAST(NULL AS String)); select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- function array_prepend +select array_prepend(array(1, 2, 3), 4); +select array_prepend(array('a', 'b', 'c'), 'd'); +select array_prepend(array(1, 2, 3, NULL), NULL); +select array_prepend(array('a', 'b', 'c', NULL), NULL); +select array_prepend(CAST(null AS ARRAY), 'a'); +select array_prepend(CAST(null AS ARRAY), CAST(null as String)); +select array_prepend(array(), 1); +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 0d8ef39ed60c6..d228c605705d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), 
CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From f29996292e022dcabd8ca0a306e82901b5476fd4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 21/68] Fix tests --- python/pyspark/sql/functions.py | 9 +-- .../expressions/collectionOperations.scala | 3 +- .../sql-functions/sql-expression-schema.md | 2 +- .../resources/sql-tests/results/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 548b0266d4efe..c8a709d27c7c1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7621,10 +7621,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7636,6 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array excluding given value. 
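The behavioural intent behind the coercion change in this commit (findTightestCommonType rather than the string-promotion-free widening used before) can be pinned down with two quick checks that mirror the DataFrameFunctionsSuite cases. A hedged sketch, assuming a running SparkSession named spark with this patch applied:

```scala
// An int array widens to array<double> when a double element is prepended.
val widened = spark.sql("SELECT array_prepend(array(1, 2), 1.23D)").collect()
// expected: a single row holding [1.23, 1.0, 2.0]

// A string element against an int array has no tightest common type, so
// analysis is expected to fail with DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES:
// spark.sql("SELECT array_prepend(array(1, 2), '1')").collect()
```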
@@ -7644,7 +7645,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b122a629585b2..6e2beda4bccd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes + with ComplexTypeMergingExpression with QueryErrorsBase { override def nullable: Boolean = left.nullable @@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index cf355e11fc4ea..6146b7fcb9c06 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -27,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | 
array_repeat | SELECT array_repeat('123', 2) | struct> | | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 609122a23d316..029bd767f54c4 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From 7ce00b8bd561007bb00868f45748d9372ec35699 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 22/68] Fix types --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c8a709d27c7c1..f9230f5478eb5 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7624,7 +7624,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. 
versionadded:: 3.4.0 Parameters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6e2beda4bccd8..73be6327bca15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1532,7 +1532,6 @@ case class ArrayPrepend(left: Expression, right: Expression) } override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { - case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) From 505b8e23f51d5a89e4e623f3446b009e76c0a3c2 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 23/68] Fix tests --- python/pyspark/sql/functions.py | 1 - .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9230f5478eb5..294ec3669a98b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array excluding given value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 73be6327bca15..068d18f1727b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1405,8 +1405,8 @@ case class ArrayContains(left: Expression, right: Expression) "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 4); - [4, 1, 2, 3] + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] """, group = "array_funcs", since = "3.4.0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e342903962127..4fd350d8db265 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2692,7 +2692,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "sqlExpr" -> "\"array_prepend(_1, _2)\"", "inputSql" -> "\"_1\"", "inputType" -> "\"STRING\"", From 216ca4c32477fe438420705c6f9fe0b197aaf4fc Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 24/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 294ec3669a98b..915470b06ca3e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7618,6 +7618,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7644,9 +7645,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ @@ -7677,6 +7679,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From c121168e60f821fd013c38c0af6897949ef16319 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:30:18 -0800 Subject: [PATCH 25/68] Add test for null cases --- .../expressions/collectionOperations.scala | 28 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 7 +++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 068d18f1727b3..0a4680193e014 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1418,6 +1418,9 @@ case class ArrayPrepend(left: Expression, right: Expression) override def nullable: Boolean = left.nullable + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + override def eval(input: InternalRow): Any = { val value1 = left.eval(input) if (value1 == null) { @@ -1427,23 +1430,16 @@ case class ArrayPrepend(left: 
Expression, right: Expression) nullSafeEval(value1, value2) } } - override def nullSafeEval(arr: Any, value: Any): Any = { - val numberOfElements = arr.asInstanceOf[ArrayData].numElements() - if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) } - val newArray = new Array[Any](numberOfElements + 1) - newArray(0) = value - var pos = 1 - arr - .asInstanceOf[ArrayData] - .foreach( - right.dataType, - (i, v) => { - newArray(pos) = v - pos += 1 - }) - new GenericArrayData(newArray) + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) @@ -1505,7 +1501,7 @@ case class ArrayPrepend(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) - override def dataType: DataType = left.dataType + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4fd350d8db265..bc096f923fa2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2710,6 +2710,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "leftType" -> "\"ARRAY\"", "rightType" -> "\"STRING\""), queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) } test("array remove") { From ec503f9f7b718b02475eacb1fb6a7ce0ed52362c Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:35:43 -0800 Subject: [PATCH 26/68] Fix type of array --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0a4680193e014..d27f3d3f78517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1451,13 +1451,13 @@ case class ArrayPrepend(left: Expression, right: Expression) val pos = ctx.freshName("pos") val allocation = CodeGenerator.createArrayData( newArray, - right.dataType, + elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, 
right.dataType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) s""" |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; From f1c01860c087cacc6025a9be9e2900502e29e5a4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 28 Feb 2023 21:06:54 -0800 Subject: [PATCH 27/68] Address comments --- python/pyspark/sql/functions.py | 8 +++---- .../expressions/collectionOperations.scala | 22 ++++++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 915470b06ca3e..12142b9e2ca0c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7620,7 +7620,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions -def array_prepend(col: "ColumnOrName", element: Any) -> Column: +def array_prepend(col: "ColumnOrName", value: Any) -> Column: """ Collection function: Returns an array containing element as well as all elements from array. The new element is positioned @@ -7632,8 +7632,8 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: ---------- col : :class:`~pyspark.sql.Column` or str name of column containing array - element : - element to be prepended to the array + value : + a literal value, or a :class:`~pyspark.sql.Column` expression. Returns ------- @@ -7646,7 +7646,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] """ - return _invoke_function("array_prepend", _to_java_column(col), element) + return _invoke_function_over_columns("array_prepend", col, lit(value)) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d27f3d3f78517..62b1c5afaa083 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1401,12 +1401,20 @@ case class ArrayContains(left: Expression, right: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = - "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. The new element is positioned at the beginning of the array.", + usage = """ + _FUNC_(array, element) - Add the element at the beginning of the array passed as first + argument. Type of element should be similar to type of the elements of the array. + Null element is also prepended to the array. 
But if the array passed is NULL + output is NULL + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); ["d","b","d","c","a"] + > SELECT _FUNC_(array(1, 2, 3, null), null); + [null,1,2,3,null] + > SELECT _FUNC_(CAST(null as Array), 2); + NULL """, group = "array_funcs", since = "3.4.0") @@ -1448,25 +1456,23 @@ case class ArrayPrepend(left: Expression, right: Expression) val newArraySize = ctx.freshName("newArraySize") val newArray = ctx.freshName("newArray") val i = ctx.freshName("i") - val pos = ctx.freshName("pos") + val iPlus1 = s"$i+1" + val zero = "0" val allocation = CodeGenerator.createArrayData( newArray, elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) s""" - |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; |$allocation |$newElemAssignment - |$pos = $pos + 1; |for (int $i = 0; $i < $arr.numElements(); $i ++) { | $assignment - | $pos = $pos + 1; |} |${ev.value} = $newArray; |""".stripMargin From 34cb724a3538717b412eb549038ab841e3185437 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 12 Mar 2023 18:54:50 -0700 Subject: [PATCH 28/68] Update version --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 12142b9e2ca0c..dac7cddb880aa 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7626,7 +7626,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: well as all elements from array. The new element is positioned at the beginning of the array. - .. versionadded:: 3.4.0 + .. versionadded:: 3.5.0 Parameters ---------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 62b1c5afaa083..66efec732fe3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1417,7 +1417,7 @@ case class ArrayContains(left: Expression, right: Expression) NULL """, group = "array_funcs", - since = "3.4.0") + since = "3.5.0") case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 069e7b79fcb7a..d771367f318ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4048,7 +4048,7 @@ object functions { * positioned at the beginning of the array. 
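On the Scala side, the registered API is small enough that its contract fits in a few lines. A hedged usage sketch, assuming a SparkSession with this patch applied and import spark.implicits._ in scope; the expected rows mirror the DataFrameFunctionsSuite cases:

```scala
import org.apache.spark.sql.functions.array_prepend

val df = Seq((Array(2, 3, 4), 1), (null.asInstanceOf[Array[Int]], 1)).toDF("data", "d")

df.select(array_prepend($"data", 0)).collect()
// expected: Row([0, 2, 3, 4]) and Row(null) -- a null array stays null

df.select(array_prepend($"data", $"d")).collect()
// expected: Row([1, 2, 3, 4]) and Row(null) -- the element may come from a column
```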
* * @group collection_funcs - * @since 3.4.0 + * @since 3.5.0 */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) From baa6cc730af9389e4f695c6d3157642e1d238414 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 15 Mar 2023 22:34:27 -0700 Subject: [PATCH 29/68] Address review comments --- .../expressions/collectionOperations.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 66efec732fe3b..b621f3df100ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1403,7 +1403,7 @@ case class ArrayContains(left: Expression, right: Expression) @ExpressionDescription( usage = """ _FUNC_(array, element) - Add the element at the beginning of the array passed as first - argument. Type of element should be similar to type of the elements of the array. + argument. Type of element should be the same as the type of the elements of the array. Null element is also prepended to the array. But if the array passed is NULL output is NULL """, @@ -1453,7 +1453,7 @@ case class ArrayPrepend(left: Expression, right: Expression) val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) val f = (arr: String, value: String) => { - val newArraySize = ctx.freshName("newArraySize") + val newArraySize = s"$arr.numElements() + 1" val newArray = ctx.freshName("newArray") val i = ctx.freshName("i") val iPlus1 = s"$i+1" @@ -1468,7 +1468,6 @@ case class ArrayPrepend(left: Expression, right: Expression) val newElemAssignment = CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull)) s""" - |int $newArraySize = $arr.numElements() + 1; |$allocation |$newElemAssignment |for (int $i = 0; $i < $arr.numElements(); $i ++) { @@ -1487,17 +1486,19 @@ case class ArrayPrepend(left: Expression, right: Expression) } ev.copy(code = code""" - boolean ${ev.isNull} = true; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $nullSafeEval - """) + |boolean ${ev.isNull} = true; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$nullSafeEval + """.stripMargin + ) } else { ev.copy(code = code""" - ${leftGen.code} - ${rightGen.code} - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = FalseLiteral) + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin, isNull = FalseLiteral) } } From 8aa8ae525de7e83dfd8f0e855069f053068282ee Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:34:49 -0800 Subject: [PATCH 30/68] Adds a array_prepend expression to catalyst --- .../reference/pyspark.sql/functions.rst | 1 + python/pyspark/sql/functions.py | 27 ++++- .../catalyst/analysis/FunctionRegistry.scala | 12 ++ .../expressions/collectionOperations.scala | 113 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 41 +++++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 52 ++++++++ 7 files changed, 255 insertions(+), 1 
deletion(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 70fc04ef9cf23..cbc46e1fae18c 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -159,6 +159,7 @@ Collection Functions array_sort array_insert array_remove + array_prepend array_distinct array_intersect array_union diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 051fd52a13c02..2ee7e44c670a4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +def array_prepend(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Returns an array containing value as well as all elements from array. + The new element is positioned at the beginning of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be prepended to the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array excluding given value. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: @@ -7661,7 +7687,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ad82a83619931..7ff11b15c6eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -697,6 +697,7 @@ object FunctionRegistry { expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayPrepend]("array_prepend"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), @@ -970,6 +971,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -979,6 +981,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 289859d420bba..c003371c27a11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1399,6 +1399,119 @@ case class ArrayContains(left: Expression, right: Expression) copy(left = newLeft, right = newRight) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. The new element is positioned at the beginning of the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 4); + [4, 1, 2, 3] + """, + group = "array_funcs", + since = "3.4.0") +case class ArrayPrepend(left: Expression, right: Expression) + extends BinaryExpression + with ImplicitCastInputTypes + with NullIntolerant + with QueryErrorsBase { + + override def nullSafeEval(arr: Any, value: Any): Any = { + val numberOfElements = arr.asInstanceOf[ArrayData].numElements() + if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + } + val newArray = new Array[Any](numberOfElements + 1) + newArray(0) = value + var pos = 1 + arr + .asInstanceOf[ArrayData] + .foreach( + right.dataType, + (i, v) => { + newArray(pos) = v + pos += 1 + }) + new GenericArrayData(newArray) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen( + ctx, + ev, + (arr, value) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + }) + } + + override def prettyName: String = "array_prepend" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): ArrayPrepend = + copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (_, NullType) | (NullType, _) => + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) + case (l, _) if !ArrayType.acceptsType(l) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(ArrayType), + "inputSql" -> toSQLExpr(left), + "inputType" -> toSQLType(left.dataType))) + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, prettyName) + case _ => + DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> 
toSQLId(prettyName), + "dataType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType))) + } + } + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } +} + /** * Checks if the two arrays contain at least one common element. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 60300ba62f2f5..63bfc76179f7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1855,6 +1855,47 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + test("SPARK-41233: ArrayPrepend") { + val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) + checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) + checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a3, Literal("a")), null) + checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) + + // complex data types + val b0 = Literal.create( + Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + checkEvaluation(ArrayPrepend(b0, nullBinary), null) + val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation( + ArrayPrepend(b1, dataToPrepend1), + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null)) + + val c0 = Literal.create( + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType)) + checkEvaluation( + ArrayPrepend(c0, dataToPrepend2), + Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation( + ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))), + Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4))) + } + test("Array remove") { val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb5c1ad5c4954..d2f1df8780b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { 
ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bd03d29282042..fcff7fb6adf2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2651,6 +2651,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("SPARK-41233: array prepend") { + val df = Seq( + (Array[Int](2, 3, 4), Array("b", "c", "d"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2)).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_prepend($"a", 1), array_prepend($"b", "a"), array_prepend($"c", "")), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkAnswer( + df.select(array_prepend($"a", $"d")), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + df.selectExpr("array_prepend(a, d)"), + Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + checkAnswer( + OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), + Seq(Row(Seq(1.23, 1.0, 2.0)))) + checkAnswer( + df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), + Seq( + Row(Seq(1, 2, 3, 4), Seq("a", "b", "c", "d"), Seq("", "", "")), + Row(Seq(1), Seq("a"), Seq("")), + Row(null, null, null))) + checkError( + exception = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_prepend(_1, _2)") + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "paramIndex" -> "1", + "sqlExpr" -> "\"array_prepend(_1, _2)\"", + "inputSql" -> "\"_1\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext("", "", 0, 20, "array_prepend(_1, _2)"))) + checkError( + exception = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_prepend(array(1, 2), '1')") + }, + errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"array_prepend(array(1, 2), 1)\"", + "functionName" -> "`array_prepend`", + "dataType" -> "\"ARRAY\"", + "leftType" -> "\"ARRAY\"", + "rightType" -> "\"STRING\""), + queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + } + test("array remove") { val df = Seq( (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), From 1ba91c7e8755720d79c65a50321dcd735c65942c Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 31/68] Fix null handling --- .../expressions/collectionOperations.scala | 122 +++++++++++------- .../CollectionExpressionsSuite.scala | 19 +-- .../org/apache/spark/sql/functions.scala | 15 +++ 3 files changed, 101 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 
c003371c27a11..6443f342e56a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,9 +1413,19 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes - with NullIntolerant with QueryErrorsBase { + override def nullable: Boolean = left.nullable + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } override def nullSafeEval(arr: Any, value: Any): Any = { val numberOfElements = arr.asInstanceOf[ArrayData].numElements() if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { @@ -1435,36 +1445,57 @@ case class ArrayPrepend(left: Expression, right: Expression) new GenericArrayData(newArray) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen( - ctx, - ev, - (arr, value) => { - val newArraySize = ctx.freshName("newArraySize") - val newArray = ctx.freshName("newArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val allocation = CodeGenerator.createArrayData( - newArray, - right.dataType, - newArraySize, - s" $prettyName failed.") - val assignment = - CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) - val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value) + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val f = (arr: String, value: String) => { + val newArraySize = ctx.freshName("newArraySize") + val newArray = ctx.freshName("newArray") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val allocation = CodeGenerator.createArrayData( + newArray, + right.dataType, + newArraySize, + s" $prettyName failed.") + val assignment = + CodeGenerator.createArrayAssignment(newArray, right.dataType, arr, pos, i, false) + val newElemAssignment = + CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + s""" + |int $pos = 0; + |int $newArraySize = $arr.numElements() + 1; + |$allocation + |$newElemAssignment + |$pos = $pos + 1; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $assignment + | $pos = $pos + 1; + |} + |${ev.value} = $newArray; + |""".stripMargin + } + val resultCode = f(leftGen.value, rightGen.value) + if(nullable) { + val nullSafeEval = leftGen.code + rightGen.code + ctx.nullSafeExec(nullable, leftGen.isNull) { s""" - |int $pos = 0; - |int $newArraySize = $arr.numElements() + 1; - |$allocation - |$newElemAssignment - |$pos = $pos + 1; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $assignment - | $pos = $pos + 1; - |} - |${ev.value} = $newArray; + |${ev.isNull} = false; + |${resultCode} |""".stripMargin - }) + } + ev.copy(code = + code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $nullSafeEval + """) + } else { + ev.copy(code = + code""" + ${leftGen.code} + ${rightGen.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $resultCode""", isNull = FalseLiteral) + } } override def prettyName: String = "array_prepend" @@ -1472,31 +1503,30 @@ case class ArrayPrepend(left: Expression, right: Expression) 
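// Editorial sketch, not part of the diff: the null behaviour this commit settles
// on, exercised directly against the expression. Assumes the Catalyst test
// imports used elsewhere in this series (Literal, ArrayType, IntegerType).
val arr = Literal.create(Seq(1, 2), ArrayType(IntegerType))
ArrayPrepend(arr, Literal(0)).eval()                  // prepends 0: [0, 1, 2]
ArrayPrepend(arr, Literal(null, IntegerType)).eval()  // null element still prepended: [null, 1, 2]
ArrayPrepend(Literal.create(null, ArrayType(IntegerType)), Literal(0)).eval()  // null array yields null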
override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) + override def dataType: DataType = left.dataType + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (_, NullType) | (NullType, _) => - DataTypeMismatch( - errorSubClass = "NULL_TYPE", - messageParameters = Map("functionName" -> toSQLId(prettyName))) - case (l, _) if !ArrayType.acceptsType(l) => + case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) => DataTypeMismatch( + errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType), + "dataType" -> toSQLType(ArrayType) + )) + case _ => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "requiredType" -> toSQLType(ArrayType), "inputSql" -> toSQLExpr(left), - "inputType" -> toSQLType(left.dataType))) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, prettyName) - case _ => - DataTypeMismatch( - errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", - messageParameters = Map( - "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(ArrayType), - "leftType" -> toSQLType(left.dataType), - "rightType" -> toSQLType(right.dataType))) + "inputType" -> toSQLType(left.dataType) + ) + ) } } override def inputTypes: Seq[AbstractDataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 63bfc76179f7f..1d00ec0cd8d2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1864,21 +1862,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4)) checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c")) checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1)) - checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), null) + checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null)) checkEvaluation(ArrayPrepend(a3, Literal("a")), null) checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null) // complex data types + val data = Seq[Array[Byte]]( + Array[Byte](5, 6), + Array[Byte](1, 2), + Array[Byte](1, 2), + Array[Byte](5, 6)) val b0 = Literal.create( - Seq[Array[Byte]]( - Array[Byte](5, 6), - Array[Byte](1, 2), - Array[Byte](1, 2), - Array[Byte](5, 6)), + data, ArrayType(BinaryType)) val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - checkEvaluation(ArrayPrepend(b0, nullBinary), null) + // 
Calling ArrayPrepend with a null element should result in NULL being prepended to the array + val dataWithNullPrepended = null +: data + checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended) val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType) checkEvaluation( ArrayPrepend(b1, dataToPrepend1), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f1df8780b3f..1f66a5daa2d26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. + */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 90c0c28bdc08f944c1cbfb5a151ec929decacb21 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 32/68] Fix --- .../scala/org/apache/spark/sql/functions.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1f66a5daa2d26..069e7b79fcb7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,7 +4042,7 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) - + } /** * Returns an array containing value as well as all elements from array. The new element is * positioned at the beginning of the array. @@ -4050,21 +4050,6 @@ object functions { * @group collection_funcs * @since 3.4.0 */ - def array_prepend(column: Column, element: Any): Column = withExpr { - ArrayPrepend(column.expr, lit(element).expr) - - /** - * Remove all null elements from the given array. - * - * @group collection_funcs - * @since 3.4.0 - */ - def array_compact(column: Column): Column = withExpr { - ArrayCompact(column.expr) - /** - * Returns an array containing value as well as all elements from array.The - * new element is positioned at the beginning of the array. 
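 *
 * A hedged usage sketch (editorial; assumes `spark.implicits._` and the SQL
 * functions import are in scope), consistent with the evaluation tests earlier
 * in this commit:
 * {{{
 *   Seq(Seq(2, 3, 4), Seq.empty[Int]).toDF("data")
 *     .select(array_prepend($"data", 1))   // rows: [1, 2, 3, 4] and [1]
 * }}}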
- */ def array_prepend(column: Column, element: Any): Column = withExpr { ArrayPrepend(column.expr, lit(element).expr) } From 0a69172a9205a005f7e5ba7cb90dd30e2ea72b53 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:41:01 -0800 Subject: [PATCH 33/68] Fix --- .../sql/catalyst/analysis/FunctionRegistry.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7ff11b15c6eb1..aca73741c6396 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,7 +971,6 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { -<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -981,16 +980,6 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) -======= - try { - builder(expressions) - } catch { - case e: AnalysisException => - val argTypes = expressions.map(_.dataType.typeName).mkString(", ") - throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( - name, argTypes, info.getUsage, e.getMessage) - } ->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } From db598804df9b6572940c7fd2637ef42c149d49d3 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:45:20 -0800 Subject: [PATCH 34/68] Lint --- .../expressions/CollectionExpressionsSuite.scala | 1 + .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1d00ec0cd8d2b..fced26284885f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone + import scala.language.implicitConversions import scala.util.Random import org.apache.spark.{SparkFunSuite, SparkRuntimeException} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fcff7fb6adf2b..c238f56123ea0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2664,13 +2664,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null, null, null))) checkAnswer( df.select(array_prepend($"a", $"d")), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + Row(Seq(2)), + Row(null))) checkAnswer( df.selectExpr("array_prepend(a, d)"), - Seq(Row(Seq(2, 2, 3, 4)), Row(Seq(2)), Row(null))) + Seq( + Row(Seq(2, 2, 3, 4)), + 
Row(Seq(2)), + Row(null))) checkAnswer( OneRowRelation().selectExpr("array_prepend(array(1, 2), 1.23D)"), - Seq(Row(Seq(1.23, 1.0, 2.0)))) + Seq( + Row(Seq(1.23, 1.0, 2.0)) + ) + ) checkAnswer( df.selectExpr("array_prepend(a, 1)", "array_prepend(b, \"a\")", "array_prepend(c, \"\")"), Seq( From f0d9329534d4af65eb8e312ae8a28be4fb435791 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Tue, 7 Feb 2023 20:47:55 -0800 Subject: [PATCH 35/68] Lint --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fced26284885f..3abc70a3d5518 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.language.implicitConversions import scala.util.Random + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow From ae5b65e5b7b851836617149d573e6e5bf2c3cc92 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Wed, 8 Feb 2023 20:46:44 -0800 Subject: [PATCH 36/68] Add examples of usage and fix test --- python/pyspark/sql/functions.py | 6 +- .../sql-functions/sql-expression-schema.md | 3 +- .../test/resources/sql-tests/inputs/array.sql | 11 +++ .../sql-tests/results/ansi/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2ee7e44c670a4..259a9b7dd6019 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,10 +7630,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +@try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing value as well as all elements from array. - The new element is positioned at the beginning of the array. + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned + at the beginning of the array. .. 
versionadded:: 3.4.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0894d03f9d412..529f4e044bbe8 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,6 +13,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -421,4 +422,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index 3d107cb6dfc07..d3c36b79d1f3a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -160,3 +160,14 @@ select array_append(CAST(null AS ARRAY), CAST(null as String)); select array_append(array(), 1); select array_append(CAST(array() AS ARRAY), CAST(NULL AS String)); select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)); + +-- function array_prepend +select array_prepend(array(1, 2, 3), 4); +select array_prepend(array('a', 'b', 'c'), 'd'); +select array_prepend(array(1, 2, 3, NULL), NULL); +select array_prepend(array('a', 'b', 'c', NULL), NULL); +select array_prepend(CAST(null AS ARRAY), 'a'); +select array_prepend(CAST(null AS ARRAY), CAST(null as String)); +select array_prepend(array(), 1); +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)); +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out index 0d8ef39ed60c6..d228c605705d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out @@ -784,3 +784,75 @@ select array_append(array(CAST(NULL AS String)), 
CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From af3ee0abbab06cb03cc0e23d41cb04722b738a1b Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Thu, 9 Feb 2023 20:44:09 -0800 Subject: [PATCH 37/68] Fix tests --- python/pyspark/sql/functions.py | 9 +-- .../expressions/collectionOperations.scala | 3 +- .../sql-functions/sql-expression-schema.md | 2 +- .../resources/sql-tests/results/array.sql.out | 72 +++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 259a9b7dd6019..80b806a02bdcc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7633,10 +7633,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ - Collection function: Returns an array containing element as - well as all elements from array. The new element is positioned + Collection function: Returns an array containing element as + well as all elements from array. The new element is positioned at the beginning of the array. - + .. versionadded:: 3.4.0 Parameters @@ -7648,6 +7648,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- + :class:`~pyspark.sql.Column` an array excluding given value. 
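# Editorial sketch, not part of the patch: the PySpark call these docstring
# tweaks describe, assuming an active SparkSession `spark`. It mirrors the
# example already used in the docstring itself.
from pyspark.sql.functions import array_prepend
df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"])
df.select(array_prepend(df.data, 1)).show()
# rows: [1, 2, 3, 4] and [1]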
@@ -7656,7 +7657,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6443f342e56a2..737608790ef20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1413,6 +1413,7 @@ case class ArrayContains(left: Expression, right: Expression) case class ArrayPrepend(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes + with ComplexTypeMergingExpression with QueryErrorsBase { override def nullable: Boolean = left.nullable @@ -1533,7 +1534,7 @@ case class ArrayPrepend(left: Expression, right: Expression) (left.dataType, right.dataType) match { case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => - TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) case _ => Seq.empty } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 529f4e044bbe8..6b5b67f984916 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -13,7 +13,6 @@ | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> | -| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> | @@ -27,6 +26,7 @@ | org.apache.spark.sql.catalyst.expressions.ArrayMax | array_max | SELECT array_max(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayMin | array_min | SELECT array_min(array(1, 20, null, 3)) | struct | | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct | +| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct> | | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | 
array_repeat | SELECT array_repeat('123', 2) | struct> | | org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 609122a23d316..029bd767f54c4 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -665,3 +665,75 @@ select array_append(array(CAST(NULL AS String)), CAST(NULL AS String)) struct> -- !query output [null,null] + + +-- !query +select array_prepend(array(1, 2, 3), 4) +-- !query schema +struct> +-- !query output +[4,1,2,3] + + +-- !query +select array_prepend(array('a', 'b', 'c'), 'd') +-- !query schema +struct> +-- !query output +["d","a","b","c"] + + +-- !query +select array_prepend(array(1, 2, 3, NULL), NULL) +-- !query schema +struct> +-- !query output +[null,1,2,3,null] + + +-- !query +select array_prepend(array('a', 'b', 'c', NULL), NULL) +-- !query schema +struct> +-- !query output +[null,"a","b","c",null] + + +-- !query +select array_prepend(CAST(null AS ARRAY), 'a') +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(CAST(null AS ARRAY), CAST(null as String)) +-- !query schema +struct> +-- !query output +NULL + + +-- !query +select array_prepend(array(), 1) +-- !query schema +struct> +-- !query output +[1] + + +-- !query +select array_prepend(CAST(array() AS ARRAY), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null] + + +-- !query +select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)) +-- !query schema +struct> +-- !query output +[null,null] From 3265717d555f55f234f398c43f541f17e9442043 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 10 Feb 2023 22:03:34 -0800 Subject: [PATCH 38/68] Fix types --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 80b806a02bdcc..cfd0f8674378e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7636,7 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. - + .. 
versionadded:: 3.4.0 Parameters diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 737608790ef20..1052f8053050b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1532,7 +1532,6 @@ case class ArrayPrepend(left: Expression, right: Expression) } override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { - case (_, NullType) => Seq.empty case (ArrayType(e1, hasNull), e2) => TypeCoercion.findTightestCommonType(e1, e2) match { case Some(dt) => Seq(ArrayType(dt, hasNull), dt) From 7df63ea1df13f7f33f23fbc182cd5cced95d49e4 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 16:46:34 -0800 Subject: [PATCH 39/68] Fix tests --- python/pyspark/sql/functions.py | 1 - .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cfd0f8674378e..554037eb0dff3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7648,7 +7648,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: Returns ------- - :class:`~pyspark.sql.Column` an array excluding given value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1052f8053050b..1c6e17a1bdc0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1405,8 +1405,8 @@ case class ArrayContains(left: Expression, right: Expression) "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. 
The new element is positioned at the beginning of the array.", examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 4); - [4, 1, 2, 3] + > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd'); + ["d","b","d","c","a"] """, group = "array_funcs", since = "3.4.0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c238f56123ea0..a8929d5c8bdcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2692,7 +2692,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }, errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( - "paramIndex" -> "1", + "paramIndex" -> "0", "sqlExpr" -> "\"array_prepend(_1, _2)\"", "inputSql" -> "\"_1\"", "inputType" -> "\"STRING\"", From b4fbbd509102ca847e2c3cc0310481d898fea4d7 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 19:18:02 -0800 Subject: [PATCH 40/68] Fix python linter --- python/pyspark/sql/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 554037eb0dff3..6608f8d317e61 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) + @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: """ @@ -7656,9 +7657,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ return _invoke_function("array_prepend", _to_java_column(col), element) + @try_remote_functions def array_remove(col: "ColumnOrName", element: Any) -> Column: """ @@ -7689,6 +7691,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) + @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ From 413af39a50279f8ca3492077cba9dc8666796d0a Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:30:18 -0800 Subject: [PATCH 41/68] Add test for null cases --- .../expressions/collectionOperations.scala | 28 ++++++++----------- .../spark/sql/DataFrameFunctionsSuite.scala | 7 +++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1c6e17a1bdc0c..b52eaf3cff619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1418,6 +1418,9 @@ case class ArrayPrepend(left: Expression, right: Expression) override def nullable: Boolean = left.nullable + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + override def eval(input: InternalRow): Any = { val value1 = left.eval(input) if (value1 == null) { @@ -1427,23 +1430,16 @@ case class ArrayPrepend(left: 
Expression, right: Expression) nullSafeEval(value1, value2) } } - override def nullSafeEval(arr: Any, value: Any): Any = { - val numberOfElements = arr.asInstanceOf[ArrayData].numElements() - if (numberOfElements + 1 > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + override def nullSafeEval(arr: Any, elementData: Any): Any = { + val arrayData = arr.asInstanceOf[ArrayData] + val numberOfElements = arrayData.numElements() + 1 + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) } - val newArray = new Array[Any](numberOfElements + 1) - newArray(0) = value - var pos = 1 - arr - .asInstanceOf[ArrayData] - .foreach( - right.dataType, - (i, v) => { - newArray(pos) = v - pos += 1 - }) - new GenericArrayData(newArray) + val finalData = new Array[Any](numberOfElements) + finalData.update(0, elementData) + arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, v)) + new GenericArrayData(finalData) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) @@ -1505,7 +1501,7 @@ case class ArrayPrepend(left: Expression, right: Expression) newLeft: Expression, newRight: Expression): ArrayPrepend = copy(left = newLeft, right = newRight) - override def dataType: DataType = left.dataType + override def dataType: DataType = if (right.nullable) left.dataType.asNullable else left.dataType override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a8929d5c8bdcf..355f2dfffb57f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2710,6 +2710,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "leftType" -> "\"ARRAY\"", "rightType" -> "\"STRING\""), queryContext = Array(ExpectedContext("", "", 0, 30, "array_prepend(array(1, 2), '1')"))) + val df2 = Seq((Array[String]("a", "b", "c"), "d"), + (null, "d"), + (Array[String]("x", "y", "z"), null), + (null, null) + ).toDF("a", "b") + checkAnswer(df2.selectExpr("array_prepend(a, b)"), + Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", "z")), Row(null))) } test("array remove") { From 9992e33836e6e9df26d24f86a88366c5a8a0e037 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Sun, 26 Feb 2023 20:35:43 -0800 Subject: [PATCH 42/68] Fix type of array --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b52eaf3cff619..f35c1c15243e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1451,13 +1451,13 @@ case class ArrayPrepend(left: Expression, right: Expression) val pos = ctx.freshName("pos") val allocation = CodeGenerator.createArrayData( newArray, - right.dataType, + elementType, newArraySize, s" $prettyName failed.") val assignment = - CodeGenerator.createArrayAssignment(newArray, 
right.dataType, arr, pos, i, false) + CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false) val newElemAssignment = - CodeGenerator.setArrayElement(newArray, right.dataType, pos, value, Some(rightGen.isNull)) + CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull)) s""" |int $pos = 0; |int $newArraySize = $arr.numElements() + 1; From 684a7d96f658f3905ef5194b0e7aaa1639de0538 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:34:49 -0800 Subject: [PATCH 43/68] Adds a array_prepend expression to catalyst --- python/pyspark/sql/functions.py | 27 ++++++++++++++++++- .../catalyst/analysis/FunctionRegistry.scala | 11 ++++++++ .../org/apache/spark/sql/functions.scala | 10 +++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6608f8d317e61..22215b5958a98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7630,6 +7630,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: return _invoke_function_over_columns("get", col, index) +def array_prepend(col: "ColumnOrName", element: Any) -> Column: + """ + Collection function: Returns an array containing value as well as all elements from array. + The new element is positioned at the beginning of the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + element : + element to be prepended to the array + + Returns + ------- + :class:`~pyspark.sql.Column` + an array excluding given value. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df.select(array_prepend(df.data, 1)).collect() + [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] + """ + return _invoke_function("array_prepend", _to_java_column(col), element) @try_remote_functions def array_prepend(col: "ColumnOrName", element: Any) -> Column: @@ -7691,7 +7717,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: """ return _invoke_function("array_remove", _to_java_column(col), element) - @try_remote_functions def array_distinct(col: "ColumnOrName") -> Column: """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aca73741c6396..7ff11b15c6eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -971,6 +971,7 @@ object TableFunctionRegistry { : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) val newBuilder = (expressions: Seq[Expression]) => { +<<<<<<< HEAD val generator = builder(expressions) assert(generator.isInstanceOf[Generator]) Generate( @@ -980,6 +981,16 @@ object TableFunctionRegistry { qualifier = None, generatorOutput = Nil, child = OneRowRelation()) +======= + try { + builder(expressions) + } catch { + case e: AnalysisException => + val argTypes = expressions.map(_.dataType.typeName).mkString(", ") + throw QueryCompilationErrors.cannotApplyTableValuedFunctionError( + name, argTypes, info.getUsage, e.getMessage) + } +>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst" } (name, (info, newBuilder)) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 069e7b79fcb7a..b0538af2fd6c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4042,6 +4042,16 @@ object functions { */ def array_compact(column: Column): Column = withExpr { ArrayCompact(column.expr) + + /** + * Returns an array containing value as well as all elements from array. The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) } /** * Returns an array containing value as well as all elements from array. The new element is From 93f181917f9b300b8eb38f11328335a2d7b70a56 Mon Sep 17 00:00:00 2001 From: Navin Viswanath Date: Fri, 27 Jan 2023 20:36:47 -0800 Subject: [PATCH 44/68] Fix null handling --- .../expressions/CollectionExpressionsSuite.scala | 2 -- .../scala/org/apache/spark/sql/functions.scala | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3abc70a3d5518..1d00ec0cd8d2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDateTime, Period} import java.util.TimeZone - import scala.language.implicitConversions import scala.util.Random - import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b0538af2fd6c0..f99fbbe8eefaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4050,6 +4050,21 @@ object functions { * @group collection_funcs * @since 3.4.0 */ + def array_prepend(column: Column, element: Any): Column = withExpr { + ArrayPrepend(column.expr, lit(element).expr) + + /** + * Remove all null elements from the given array. + * + * @group collection_funcs + * @since 3.4.0 + */ + def array_compact(column: Column): Column = withExpr { + ArrayCompact(column.expr) + /** + * Returns an array containing value as well as all elements from array.The + * new element is positioned at the beginning of the array. 
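 *
 * Editorial aside, not part of the commit: because the element argument goes
 * through `lit`, a raw Scala value and a Column behave identically here (a
 * DataFrame `df` with an array column "xs" is assumed):
 * {{{
 *   df.select(array_prepend($"xs", 0))        // element as a Scala literal
 *   df.select(array_prepend($"xs", lit(0)))   // element already wrapped as a Column
 * }}}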
+   */
   def array_prepend(column: Column, element: Any): Column = withExpr {
     ArrayPrepend(column.expr, lit(element).expr)
   }

From a8db7b3823778430d94a614861f7a4dbbe04ae7f Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:41:01 -0800
Subject: [PATCH 45/68] Fix

---
 .../scala/org/apache/spark/sql/functions.scala | 17 +----------------
 1 file changed, 1 insertion(+), 16 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f99fbbe8eefaf..9c102767c4224 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4042,7 +4042,7 @@ object functions {
    */
   def array_compact(column: Column): Column = withExpr {
     ArrayCompact(column.expr)
-
+  }
   /**
    * Returns an array containing value as well as all elements from array. The new element is
    * positioned at the beginning of the array.
@@ -4050,21 +4050,6 @@ object functions {
    * @group collection_funcs
    * @since 3.4.0
    */
-  def array_prepend(column: Column, element: Any): Column = withExpr {
-    ArrayPrepend(column.expr, lit(element).expr)
-
-  /**
-   * Remove all null elements from the given array.
-   *
-   * @group collection_funcs
-   * @since 3.4.0
-   */
-  def array_compact(column: Column): Column = withExpr {
-    ArrayCompact(column.expr)
-  /**
-   * Returns an array containing value as well as all elements from array. The
-   * new element is positioned at the beginning of the array.
-   */
   def array_prepend(column: Column, element: Any): Column = withExpr {
     ArrayPrepend(column.expr, lit(element).expr)
   }

From 82186b9a4fd60f2886d630d951aeced9b6de36d3 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:41:01 -0800
Subject: [PATCH 46/68] Fix

---
 .../sql/catalyst/analysis/FunctionRegistry.scala | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 7ff11b15c6eb1..aca73741c6396 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -971,7 +971,6 @@ object TableFunctionRegistry {
       : (String, (ExpressionInfo, TableFunctionBuilder)) = {
     val (info, builder) = FunctionRegistryBase.build[T](name, since = None)
     val newBuilder = (expressions: Seq[Expression]) => {
-<<<<<<< HEAD
       val generator = builder(expressions)
       assert(generator.isInstanceOf[Generator])
       Generate(
@@ -981,16 +980,6 @@ object TableFunctionRegistry {
         qualifier = None,
         generatorOutput = Nil,
         child = OneRowRelation())
-=======
-      try {
-        builder(expressions)
-      } catch {
-        case e: AnalysisException =>
-          val argTypes = expressions.map(_.dataType.typeName).mkString(", ")
-          throw QueryCompilationErrors.cannotApplyTableValuedFunctionError(
-            name, argTypes, info.getUsage, e.getMessage)
-      }
->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst"
     }
     (name, (info, newBuilder))
   }
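PATCH 45 closes the brace that PATCH 43/44 left dangling inside array_compact, and PATCH 46 strips the leftover conflict markers from TableFunctionRegistry. A hedged illustration of the two neighbouring functions once those fixes apply (again assuming a SparkSession named spark and its implicits; this is not a hunk from the series):

    import spark.implicits._
    import org.apache.spark.sql.functions.{array_compact, array_prepend, col}

    val df = Seq(Seq(Option(1), None, Option(3))).toDF("data")
    df.select(array_compact(col("data"))).show()    // [1, 3]
    df.select(array_prepend(col("data"), 0)).show() // [0, 1, null, 3]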
From 30988b7ad179cdb6f811393f318add51c72d8f49 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Tue, 7 Feb 2023 20:45:20 -0800
Subject: [PATCH 47/68] Lint

---
 .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1d00ec0cd8d2b..fced26284885f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, LocalDateTime, Period}
 import java.util.TimeZone
+
 import scala.language.implicitConversions
 import scala.util.Random
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}

From 09a61cad634bb1874a4b3ae761e6922e5cd44be7 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Tue, 7 Feb 2023 20:47:55 -0800
Subject: [PATCH 48/68] Lint

---
 .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index fced26284885f..3abc70a3d5518 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.util.TimeZone
 
 import scala.language.implicitConversions
 import scala.util.Random
+
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow

From d188279dd096b3e2927da2c3de4c02445fd07760 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Wed, 8 Feb 2023 20:46:44 -0800
Subject: [PATCH 49/68] Add examples of usage and fix test

---
 python/pyspark/sql/functions.py                              | 6 ++++--
 .../test/resources/sql-functions/sql-expression-schema.md    | 1 +
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 22215b5958a98..b8b09182e9072 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7630,10 +7630,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+@try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
-    Collection function: Returns an array containing value as well as all elements from array.
-    The new element is positioned at the beginning of the array.
+    Collection function: Returns an array containing element as
+    well as all elements from array. The new element is positioned
+    at the beginning of the array.
 
     .. versionadded:: 3.4.0
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 6b5b67f984916..0cbb896fe03de 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -13,6 +13,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> |
+| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> |
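The schema row added above pins down the SQL-level entry point that the expression-schema test checks. A hedged way to reproduce it by hand (assuming a SparkSession named spark; the expected output follows the examples documented for the expression in PATCH 54 below):

    // Sketch: running the schema file's example query directly.
    spark.sql("SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd')").show(truncate = false)
    // ["d","b","d","c","a"]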
From 15b713d81b5578ae6d77b816ace8900a3decc9cb Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Thu, 9 Feb 2023 20:44:09 -0800
Subject: [PATCH 50/68] Fix tests

---
 python/pyspark/sql/functions.py                          | 9 +++++----
 .../resources/sql-functions/sql-expression-schema.md     | 1 -
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b8b09182e9072..f57b9ee58c798 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7633,10 +7633,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
-    Collection function: Returns an array containing element as 
-    well as all elements from array. The new element is positioned 
+    Collection function: Returns an array containing element as
+    well as all elements from array. The new element is positioned
     at the beginning of the array.
-    
+
     .. versionadded:: 3.4.0
 
     Parameters
@@ -7648,6 +7648,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
 
     Returns
     -------
+    :class:`~pyspark.sql.Column`
         an array with the given value prepended.
@@ -7656,7 +7657,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
     >>> df.select(array_prepend(df.data, 1)).collect()
     [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
-    """ 
+    """
     return _invoke_function("array_prepend", _to_java_column(col), element)
 
 @try_remote_functions
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 0cbb896fe03de..6b5b67f984916 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -13,7 +13,6 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> |
-| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> |

From 380b15658decb3ba8e2b6a4aaaa5a417359ea965 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 10 Feb 2023 22:03:34 -0800
Subject: [PATCH 51/68] Fix types

---
 python/pyspark/sql/functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f57b9ee58c798..f08beff974345 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7636,7 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     Collection function: Returns an array containing element as
     well as all elements from array. The new element is positioned
     at the beginning of the array.
-    
+
     .. versionadded:: 3.4.0
 
     Parameters

From 4ecfac854048e07002569bcb20d6e35567f85b97 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Sun, 26 Feb 2023 16:46:34 -0800
Subject: [PATCH 52/68] Fix tests

---
 python/pyspark/sql/functions.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f08beff974345..2128e1ba1d1e8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7648,7 +7648,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
 
     Returns
     -------
-    :class:`~pyspark.sql.Column`
         an array with the given value prepended.
From 160db20dc6b3b128b4bdacf6718e63b271a8fdb4 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Sun, 26 Feb 2023 19:18:02 -0800
Subject: [PATCH 53/68] Fix python linter

---
 python/pyspark/sql/functions.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2128e1ba1d1e8..66f435f3614a8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7630,6 +7630,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
@@ -7656,9 +7657,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
     >>> df.select(array_prepend(df.data, 1)).collect()
     [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
-    """ 
+    """
     return _invoke_function("array_prepend", _to_java_column(col), element)
 
+
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     """
@@ -7719,6 +7721,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     """
     return _invoke_function("array_remove", _to_java_column(col), element)
 
+
 @try_remote_functions
 def array_distinct(col: "ColumnOrName") -> Column:
     """

From 3aa673f5213106c66b060039baffe6825fd3832c Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Tue, 28 Feb 2023 21:06:54 -0800
Subject: [PATCH 54/68] Address comments

---
 python/pyspark/sql/functions.py                    |  8 +++----
 .../expressions/collectionOperations.scala         | 22 ++++++++++++-------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 66f435f3614a8..aec0455fc20b5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7632,7 +7632,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
 
 @try_remote_functions
-def array_prepend(col: "ColumnOrName", element: Any) -> Column:
+def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     """
     Collection function: Returns an array containing element as
     well as all elements from array. The new element is positioned
     at the beginning of the array.
@@ -7644,8 +7644,8 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     ----------
     col : :class:`~pyspark.sql.Column` or str
         name of column containing array
-    element :
-        element to be prepended to the array
+    value :
+        a literal value, or a :class:`~pyspark.sql.Column` expression.
 
     Returns
     -------
@@ -7658,7 +7658,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     >>> df.select(array_prepend(df.data, 1)).collect()
     [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
     """
-    return _invoke_function("array_prepend", _to_java_column(col), element)
+    return _invoke_function_over_columns("array_prepend", col, lit(value))
 
 
 @try_remote_functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f35c1c15243e0..366f035a5cd71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1401,12 +1401,20 @@ case class ArrayContains(left: Expression, right: Expression)
 
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage =
-    "_FUNC_(array, value) - Returns an array containing value as well as all elements from array. The new element is positioned at the beginning of the array.",
+  usage = """
+      _FUNC_(array, element) - Add the element at the beginning of the array passed as first
+      argument. Type of element should be similar to type of the elements of the array.
+      Null element is also prepended to the array. But if the array passed is NULL
+      output is NULL
+    """,
   examples = """
     Examples:
       > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
       ["d","b","d","c","a"]
+      > SELECT _FUNC_(array(1, 2, 3, null), null);
+      [null,1,2,3,null]
+      > SELECT _FUNC_(CAST(null as Array), 2);
+      NULL
   """,
   group = "array_funcs",
   since = "3.4.0")
 case class ArrayPrepend(left: Expression, right: Expression)
   extends BinaryExpression
   with ImplicitCastInputTypes
     val leftGen = left.genCode(ctx)
     val rightGen = right.genCode(ctx)
     val f = (arr: String, value: String) => {
       val newArraySize = ctx.freshName("newArraySize")
       val newArray = ctx.freshName("newArray")
       val i = ctx.freshName("i")
-      val pos = ctx.freshName("pos")
+      val iPlus1 = s"$i+1"
+      val zero = "0"
       val allocation = CodeGenerator.createArrayData(
         newArray, elementType, newArraySize, s" $prettyName failed.")
       val assignment =
-        CodeGenerator.createArrayAssignment(newArray, elementType, arr, pos, i, false)
+        CodeGenerator.createArrayAssignment(newArray, elementType, arr, iPlus1, i, false)
       val newElemAssignment =
-        CodeGenerator.setArrayElement(newArray, elementType, pos, value, Some(rightGen.isNull))
+        CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
       s"""
-         |int $pos = 0;
          |int $newArraySize = $arr.numElements() + 1;
          |$allocation
          |$newElemAssignment
-         |$pos = $pos + 1;
          |for (int $i = 0; $i < $arr.numElements(); $i ++) {
          |  $assignment
-         |  $pos = $pos + 1;
          |}
          |${ev.value} = $newArray;
          |""".stripMargin
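PATCH 54's codegen rewrite drops the running pos counter: the new element is written to slot 0 and the copy loop targets index i + 1. A plain-Scala sketch of what the generated Java loop computes (names invented for illustration; the real path goes through ArrayData and CodeGenerator):

    // Illustrative only: the loop structure of the generated code, written as plain Scala.
    def prependSketch[T: scala.reflect.ClassTag](arr: Array[T], elem: T): Array[T] = {
      val result = new Array[T](arr.length + 1) // newArraySize = arr.numElements() + 1
      result(0) = elem                          // setArrayElement at index zero
      var i = 0
      while (i < arr.length) {                  // createArrayAssignment at index i + 1
        result(i + 1) = arr(i)
        i += 1
      }
      result
    }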
From 8b480a5ae50640a2a6120aebdadcaa306c750753 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Sun, 12 Mar 2023 18:54:50 -0700
Subject: [PATCH 55/68] Update version

---
 python/pyspark/sql/functions.py                                | 2 +-
 .../spark/sql/catalyst/expressions/collectionOperations.scala  | 2 +-
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index aec0455fc20b5..624ff604cebee 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7638,7 +7638,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     well as all elements from array. The new element is positioned
     at the beginning of the array.
 
-    .. versionadded:: 3.4.0
+    .. versionadded:: 3.5.0
 
     Parameters
     ----------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 366f035a5cd71..eaadef7c43b2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1417,7 +1417,7 @@ case class ArrayContains(left: Expression, right: Expression)
       NULL
   """,
   group = "array_funcs",
-  since = "3.4.0")
+  since = "3.5.0")
 case class ArrayPrepend(left: Expression, right: Expression)
   extends BinaryExpression
   with ImplicitCastInputTypes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9c102767c4224..9674eda7c2afa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4048,7 +4048,7 @@ object functions {
    * positioned at the beginning of the array.
    *
    * @group collection_funcs
-   * @since 3.4.0
+   * @since 3.5.0
    */
   def array_prepend(column: Column, element: Any): Column = withExpr {
     ArrayPrepend(column.expr, lit(element).expr)

From 19fe92435de7cb4380baed9409bfd8b433d07044 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Wed, 15 Mar 2023 22:34:27 -0700
Subject: [PATCH 56/68] Address review comments

---
 .../expressions/collectionOperations.scala | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index eaadef7c43b2c..2ccb3a6d0cd57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1403,7 +1403,7 @@ case class ArrayContains(left: Expression, right: Expression)
 @ExpressionDescription(
   usage = """
       _FUNC_(array, element) - Add the element at the beginning of the array passed as first
-      argument. Type of element should be similar to type of the elements of the array.
+      argument. Type of element should be the same as the type of the elements of the array.
       Null element is also prepended to the array.
       But if the array passed is NULL
       output is NULL
   """,
@@ -1453,7 +1453,7 @@ case class ArrayPrepend(left: Expression, right: Expression)
     val leftGen = left.genCode(ctx)
     val rightGen = right.genCode(ctx)
     val f = (arr: String, value: String) => {
-      val newArraySize = ctx.freshName("newArraySize")
+      val newArraySize = s"$arr.numElements() + 1"
       val newArray = ctx.freshName("newArray")
       val i = ctx.freshName("i")
       val iPlus1 = s"$i+1"
@@ -1468,7 +1468,6 @@ case class ArrayPrepend(left: Expression, right: Expression)
       val newElemAssignment =
         CodeGenerator.setArrayElement(newArray, elementType, zero, value, Some(rightGen.isNull))
       s"""
-         |int $newArraySize = $arr.numElements() + 1;
         |$allocation
         |$newElemAssignment
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
@@ -1487,17 +1486,19 @@ case class ArrayPrepend(left: Expression, right: Expression)
 
       ev.copy(code =
         code"""
-        boolean ${ev.isNull} = true;
-        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        $nullSafeEval
-        """)
+        |boolean ${ev.isNull} = true;
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        |$nullSafeEval
+        """.stripMargin
+      )
     } else {
       ev.copy(code =
         code"""
-        ${leftGen.code}
-        ${rightGen.code}
-        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-        $resultCode""", isNull = FalseLiteral)
+        |${leftGen.code}
+        |${rightGen.code}
+        |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        |$resultCode
+        """.stripMargin, isNull = FalseLiteral)
     }
   }
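PATCH 56 also settles the documented null contract: a null element is still prepended, while a NULL input array yields NULL. A self-contained plain-Scala approximation of that contract, useful as a mental model but not the Catalyst implementation:

    def arrayPrependModel(arr: Seq[Any], elem: Any): Seq[Any] =
      if (arr == null) null else elem +: arr // null elem is kept; NULL array short-circuits

    assert(arrayPrependModel(Seq(1, 2, 3), 0) == Seq(0, 1, 2, 3))
    assert(arrayPrependModel(Seq(1, 2, 3), null) == Seq(null, 1, 2, 3))
    assert(arrayPrependModel(null, 1) == null)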
From b1cf31a6c799c2891e99eb76b1318ce2e11dc278 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:34:49 -0800
Subject: [PATCH 57/68] Adds an array_prepend expression to catalyst

---
 python/pyspark/sql/functions.py                  | 27 ++++++++++++++++++-
 .../catalyst/analysis/FunctionRegistry.scala     | 11 ++++++++
 .../org/apache/spark/sql/functions.scala         | 10 +++++++
 3 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 624ff604cebee..d8ee5074c44f4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7630,6 +7630,32 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+def array_prepend(col: "ColumnOrName", element: Any) -> Column:
+    """
+    Collection function: Returns an array containing value as well as all elements from array.
+    The new element is positioned at the beginning of the array.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    element :
+        element to be prepended to the array
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array with the given value prepended.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
+    >>> df.select(array_prepend(df.data, 1)).collect()
+    [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
+    """
+    return _invoke_function("array_prepend", _to_java_column(col), element)
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     """
@@ -7721,7 +7747,6 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     """
     return _invoke_function("array_remove", _to_java_column(col), element)
 
-
 @try_remote_functions
 def array_distinct(col: "ColumnOrName") -> Column:
     """
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index aca73741c6396..7ff11b15c6eb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -971,6 +971,7 @@ object TableFunctionRegistry {
       : (String, (ExpressionInfo, TableFunctionBuilder)) = {
     val (info, builder) = FunctionRegistryBase.build[T](name, since = None)
     val newBuilder = (expressions: Seq[Expression]) => {
+<<<<<<< HEAD
       val generator = builder(expressions)
       assert(generator.isInstanceOf[Generator])
       Generate(
@@ -980,6 +981,16 @@ object TableFunctionRegistry {
         qualifier = None,
         generatorOutput = Nil,
         child = OneRowRelation())
+=======
+      try {
+        builder(expressions)
+      } catch {
+        case e: AnalysisException =>
+          val argTypes = expressions.map(_.dataType.typeName).mkString(", ")
+          throw QueryCompilationErrors.cannotApplyTableValuedFunctionError(
+            name, argTypes, info.getUsage, e.getMessage)
+      }
+>>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst"
     }
     (name, (info, newBuilder))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9674eda7c2afa..89aeccfe6ed4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4042,6 +4042,16 @@ object functions {
    */
   def array_compact(column: Column): Column = withExpr {
     ArrayCompact(column.expr)
+
+  /**
+   * Returns an array containing value as well as all elements from array. The new element is
+   * positioned at the beginning of the array.
+   *
+   * @group collection_funcs
+   * @since 3.4.0
+   */
+  def array_prepend(column: Column, element: Any): Column = withExpr {
+    ArrayPrepend(column.expr, lit(element).expr)
   }
 
   /**
    * Returns an array containing value as well as all elements from array. The new element is
From c0c6a512871651be2e4f834b41ad64c62b3abe09 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:36:47 -0800
Subject: [PATCH 58/68] Fix null handling

---
 .../expressions/CollectionExpressionsSuite.scala |  2 --
 .../scala/org/apache/spark/sql/functions.scala   | 15 +++++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 3abc70a3d5518..1d00ec0cd8d2b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -20,10 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, LocalDateTime, Period}
 import java.util.TimeZone
-
 import scala.language.implicitConversions
 import scala.util.Random
-
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 89aeccfe6ed4f..3aa778382d1c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4050,6 +4050,21 @@ object functions {
    * @group collection_funcs
    * @since 3.4.0
    */
+  def array_prepend(column: Column, element: Any): Column = withExpr {
+    ArrayPrepend(column.expr, lit(element).expr)
+
+  /**
+   * Remove all null elements from the given array.
+   *
+   * @group collection_funcs
+   * @since 3.4.0
+   */
+  def array_compact(column: Column): Column = withExpr {
+    ArrayCompact(column.expr)
+  /**
+   * Returns an array containing value as well as all elements from array. The
+   * new element is positioned at the beginning of the array.
+   */
   def array_prepend(column: Column, element: Any): Column = withExpr {
     ArrayPrepend(column.expr, lit(element).expr)
   }

From 422f393575fb82635169e67da0e5c2a5858815a0 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 27 Jan 2023 20:41:01 -0800
Subject: [PATCH 59/68] Fix

---
 .../scala/org/apache/spark/sql/functions.scala | 17 +----------------
 1 file changed, 1 insertion(+), 16 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3aa778382d1c5..95a976d6ba325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4042,7 +4042,7 @@ object functions {
    */
   def array_compact(column: Column): Column = withExpr {
     ArrayCompact(column.expr)
-
+  }
   /**
    * Returns an array containing value as well as all elements from array. The new element is
    * positioned at the beginning of the array.
    *
    * @group collection_funcs
    * @since 3.4.0
    */
-  def array_prepend(column: Column, element: Any): Column = withExpr {
-    ArrayPrepend(column.expr, lit(element).expr)
-
-  /**
-   * Remove all null elements from the given array.
-   *
-   * @group collection_funcs
-   * @since 3.4.0
-   */
-  def array_compact(column: Column): Column = withExpr {
-    ArrayCompact(column.expr)
-  /**
-   * Returns an array containing value as well as all elements from array. The
-   * new element is positioned at the beginning of the array.
-   */
   def array_prepend(column: Column, element: Any): Column = withExpr {
     ArrayPrepend(column.expr, lit(element).expr)
   }

From 2e193e4cdac011233506e5b0271c1eceda5ac353 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Tue, 7 Feb 2023 20:45:20 -0800
Subject: [PATCH 60/68] Lint

---
 .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1d00ec0cd8d2b..fced26284885f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, LocalDateTime, Period}
 import java.util.TimeZone
+
 import scala.language.implicitConversions
 import scala.util.Random
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}

From 95673b826bd58f126fde095e15c9338f663deaa8 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Wed, 8 Feb 2023 20:46:44 -0800
Subject: [PATCH 61/68] Add examples of usage and fix test

---
 python/pyspark/sql/functions.py                              | 6 ++++--
 .../test/resources/sql-functions/sql-expression-schema.md    | 1 +
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d8ee5074c44f4..f290f22e7b4b4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7630,10 +7630,12 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+@try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
-    Collection function: Returns an array containing value as well as all elements from array.
-    The new element is positioned at the beginning of the array.
+    Collection function: Returns an array containing element as
+    well as all elements from array. The new element is positioned
+    at the beginning of the array.
 
     .. versionadded:: 3.4.0
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 6b5b67f984916..0cbb896fe03de 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -13,6 +13,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> |
+| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> |

From 46e6dd78177b98e20ba421c98ad35b999c285343 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Thu, 9 Feb 2023 20:44:09 -0800
Subject: [PATCH 62/68] Fix tests

---
 python/pyspark/sql/functions.py                          | 9 +++++----
 .../resources/sql-functions/sql-expression-schema.md     | 1 -
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f290f22e7b4b4..3b2e0e7ee7608 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7633,10 +7633,10 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
-    Collection function: Returns an array containing element as 
-    well as all elements from array. The new element is positioned 
+    Collection function: Returns an array containing element as
+    well as all elements from array. The new element is positioned
     at the beginning of the array.
-    
+
     .. versionadded:: 3.4.0
 
     Parameters
@@ -7648,6 +7648,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
 
     Returns
     -------
+    :class:`~pyspark.sql.Column`
        an array with the given value prepended.
@@ -7656,7 +7657,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
     >>> df.select(array_prepend(df.data, 1)).collect()
     [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
-    """ 
+    """
     return _invoke_function("array_prepend", _to_java_column(col), element)
 
 @try_remote_functions
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 0cbb896fe03de..6b5b67f984916 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -13,7 +13,6 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate | SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAggregate | reduce | SELECT reduce(array(1, 2, 3), 0, (acc, x) -> acc + x) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayAppend | array_append | SELECT array_append(array('b', 'd', 'c', 'a'), 'd') | struct> |
-| org.apache.spark.sql.catalyst.expressions.ArrayPrepend | array_prepend | SELECT array_prepend(array('b', 'd', 'c', 'a'), 'd') | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayCompact | array_compact | SELECT array_compact(array(1, 2, 3, null)) | struct> |
 | org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains | SELECT array_contains(array(1, 2, 3), 2) | struct |
 | org.apache.spark.sql.catalyst.expressions.ArrayDistinct | array_distinct | SELECT array_distinct(array(1, 2, 3, null, 3)) | struct> |

From 67a64daf795f87f3dc5ccd71e4c35c3365adb20b Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 10 Feb 2023 22:03:34 -0800
Subject: [PATCH 63/68] Fix types

---
 python/pyspark/sql/functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3b2e0e7ee7608..4e3354d8010db 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7636,7 +7636,7 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     Collection function: Returns an array containing element as
     well as all elements from array. The new element is positioned
     at the beginning of the array.
-    
+
     .. versionadded:: 3.4.0
 
     Parameters

From 19505ff09da6bb7636c6d1d82f6c8a33727b4871 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Sun, 26 Feb 2023 16:46:34 -0800
Subject: [PATCH 64/68] Fix tests

---
 python/pyspark/sql/functions.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4e3354d8010db..bcb88930cdcdc 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7648,7 +7648,6 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
 
     Returns
     -------
-    :class:`~pyspark.sql.Column`
        an array with the given value prepended.
From 52078ff73a36893ab65f6ccc939e693b4328ba84 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Sun, 26 Feb 2023 19:18:02 -0800
Subject: [PATCH 65/68] Fix python linter

---
 python/pyspark/sql/functions.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bcb88930cdcdc..d920f0bf7c319 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7630,6 +7630,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
     return _invoke_function_over_columns("get", col, index)
 
 
+
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     """
@@ -7656,9 +7657,10 @@ def array_prepend(col: "ColumnOrName", element: Any) -> Column:
     >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
     >>> df.select(array_prepend(df.data, 1)).collect()
     [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])]
-    """ 
+    """
     return _invoke_function("array_prepend", _to_java_column(col), element)
 
+
 @try_remote_functions
 def array_prepend(col: "ColumnOrName", value: Any) -> Column:
     """
@@ -7749,6 +7751,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     """
     return _invoke_function("array_remove", _to_java_column(col), element)
 
+
 @try_remote_functions
 def array_distinct(col: "ColumnOrName") -> Column:
     """

From 8cd56bdc62306429f5a0dd310cfa671880787cab Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Thu, 16 Mar 2023 21:12:58 -0700
Subject: [PATCH 66/68] Fix merge

---
 .../sql/catalyst/analysis/FunctionRegistry.scala | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 7ff11b15c6eb1..aca73741c6396 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -971,7 +971,6 @@ object TableFunctionRegistry {
       : (String, (ExpressionInfo, TableFunctionBuilder)) = {
     val (info, builder) = FunctionRegistryBase.build[T](name, since = None)
     val newBuilder = (expressions: Seq[Expression]) => {
-<<<<<<< HEAD
       val generator = builder(expressions)
       assert(generator.isInstanceOf[Generator])
       Generate(
@@ -981,16 +980,6 @@ object TableFunctionRegistry {
         qualifier = None,
         generatorOutput = Nil,
         child = OneRowRelation())
-=======
-      try {
-        builder(expressions)
-      } catch {
-        case e: AnalysisException =>
-          val argTypes = expressions.map(_.dataType.typeName).mkString(", ")
-          throw QueryCompilationErrors.cannotApplyTableValuedFunctionError(
-            name, argTypes, info.getUsage, e.getMessage)
-      }
->>>>>>> Revert "SPARK-41231: Adds an array_prepend function to catalyst"
     }
     (name, (info, newBuilder))
   }

From 663473798517909bd0aee150703987bd62c176ce Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Thu, 16 Mar 2023 22:36:30 -0700
Subject: [PATCH 67/68] Fix MiMa

---
 .../sql/connect/client/CheckConnectJvmClientCompatibility.scala | 1 +
 .../sql/catalyst/expressions/CollectionExpressionsSuite.scala   | 1 +
 2 files changed, 2 insertions(+)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 97d130421a242..f50520c1a54d9 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -177,6 +177,7 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.array_prepend"),
 
       // RelationalGroupedDataset
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index fced26284885f..3abc70a3d5518 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.util.TimeZone
 
 import scala.language.implicitConversions
 import scala.util.Random
+
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow

From 4dffbf70096df37a87a424c3fa77aff75f485386 Mon Sep 17 00:00:00 2001
From: Navin Viswanath
Date: Fri, 17 Mar 2023 08:25:48 -0700
Subject: [PATCH 68/68] Fix indent

---
 .../scala/org/apache/spark/sql/functions.scala | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d771367f318ce..5081f58220246 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4043,13 +4043,14 @@ object functions {
   def array_compact(column: Column): Column = withExpr {
     ArrayCompact(column.expr)
   }
-    /**
-     * Returns an array containing value as well as all elements from array. The new element is
-     * positioned at the beginning of the array.
-     *
-     * @group collection_funcs
-     * @since 3.5.0
-     */
+
+  /**
+   * Returns an array containing value as well as all elements from array. The new element is
+   * positioned at the beginning of the array.
+   *
+   * @group collection_funcs
+   * @since 3.5.0
+   */
   def array_prepend(column: Column, element: Any): Column = withExpr {
     ArrayPrepend(column.expr, lit(element).expr)
   }
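With PATCH 68 the series settles on one array_prepend surface per API: the SQL function, functions.array_prepend in Scala, and pyspark.sql.functions.array_prepend, all tagged 3.5.0. A hedged end-to-end sketch against such a build (assuming a local SparkSession named spark; output shapes follow the documented examples):

    import spark.implicits._
    import org.apache.spark.sql.functions.{array_prepend, col}

    Seq(Seq("b", "d", "c", "a")).toDF("data")
      .select(array_prepend(col("data"), "d"))
      .show(truncate = false) // [d, b, d, c, a]

    spark.sql("SELECT array_prepend(CAST(null AS ARRAY<INT>), 2)").show() // NULL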