From ace19dd7230598350838aa60fc93b32a08642acd Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Thu, 2 Aug 2018 15:11:52 +0900
Subject: [PATCH 1/5] Add `ArrayFilter`.

---
 .../catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../expressions/higherOrderFunctions.scala    | 58 +++++++++++
 .../HigherOrderFunctionsSuite.scala           | 37 +++++++
 .../inputs/higher-order-functions.sql         |  9 ++
 .../results/higher-order-functions.sql.out    | 30 +++++-
 .../spark/sql/DataFrameFunctionsSuite.scala   | 96 +++++++++++++++++++
 6 files changed, 230 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f7517486e5411..d0efe975f81ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -441,6 +441,7 @@ object FunctionRegistry {
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
+    expression[ArrayFilter]("filter"),
     CreateStruct.registryEntry,
 
     // misc functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index c5c3482afa134..ca3a274d7ad1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.util.concurrent.atomic.AtomicReference
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -210,3 +212,59 @@ case class ArrayTransform(
 
   override def prettyName: String = "transform"
 }
+
+/**
+ * Filters the input array using the given lambda function.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
+       array(1, 3)
+  """,
+  since = "2.4.0")
+case class ArrayFilter(
+    input: Expression,
+    function: Expression)
+  extends ArrayBasedHigherOrderFunction with CodegenFallback {
+
+  override def nullable: Boolean = input.nullable
+
+  override def dataType: DataType = input.dataType
+
+  override def expectingFunctionType: AbstractDataType = BooleanType
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
+    val (elementType, containsNull) = input.dataType match {
+      case ArrayType(elementType, containsNull) => (elementType, containsNull)
+      case _ =>
+        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
+        (elementType, containsNull)
+    }
+    copy(function = f(function, (elementType, containsNull) :: Nil))
+  }
+
+  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
+
+  override def eval(input: InternalRow): Any = {
+    val arr = this.input.eval(input).asInstanceOf[ArrayData]
+    if (arr == null) {
+      null
+    } else {
+      val f = functionForEval
+      val buffer = new mutable.ArrayBuffer[Any]
+      var i = 0
+      while (i < arr.numElements) {
+        elementVar.value.set(arr.get(i, elementVar.dataType))
+        if (f.eval(input).asInstanceOf[Boolean]) {
+          buffer += elementVar.value.get
+        }
+        i += 1
+      }
+      new GenericArrayData(buffer)
+    }
+  }
+
+  override def prettyName: String = "filter"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index e987ea5b8a4d1..d1330c7aad219 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -54,6 +54,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f))
   }
 
+  def filter(expr: Expression, f: Expression => Expression): Expression = {
+    val at = expr.dataType.asInstanceOf[ArrayType]
+    ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
+  }
+
   test("ArrayTransform") {
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
     val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -94,4 +99,36 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)),
       Seq("[1, 3, 5]", null, "[4, 6]"))
   }
+
+  test("ArrayFilter") {
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+    val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
+    val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))
+
+    val isEven: Expression => Expression = x => x % 2 === 0
+    val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
+
+    checkEvaluation(filter(ai0, isEven), Seq(2))
+    checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
+    checkEvaluation(filter(ai1, isEven), Seq.empty)
+    checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3))
+    checkEvaluation(filter(ain, isEven), null)
+    checkEvaluation(filter(ain, isNullOrOdd), null)
+
+    val as0 =
+      Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
+    val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
+    val asn = Literal.create(null, ArrayType(StringType, containsNull = false))
+
+    val startsWithA: Expression => Expression = x => x.startsWith("a")
+
+    checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2"))
+    checkEvaluation(filter(as1, startsWithA), Seq("a"))
+    checkEvaluation(filter(asn, startsWithA), null)
+
+    val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
+      ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
+    checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
+      Seq(Seq(1, 3), null, Seq(5)))
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index 8e928a41f08e0..f833aa5818bc1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -24,3 +24,12 @@ select transform(ys, 0) as v from nested;
 
 -- Transform a null array
 select transform(cast(null as array<int>), x -> x + 1) as v;
+
+-- Filter.
+select filter(ys, y -> y > 30) as v from nested;
+
+-- Filter a null array
+select filter(cast(null as array<int>), y -> true) as v;
+
+-- Filter nested arrays
+select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index ca2c3c35333cc..4c5d972378b31 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 8
+-- Number of queries: 11
 
 
 -- !query 0
@@ -79,3 +79,31 @@ select transform(cast(null as array<int>), x -> x + 1) as v
 struct<v:array<int>>
 -- !query 7 output
 NULL
+
+
+-- !query 8
+select filter(ys, y -> y > 30) as v from nested
+-- !query 8 schema
+struct<v:array<int>>
+-- !query 8 output
+[32,97]
+[77]
+[]
+
+
+-- !query 9
+select filter(cast(null as array<int>), y -> true) as v
+-- !query 9 schema
+struct<v:array<int>>
+-- !query 9 output
+NULL
+
+
+-- !query 10
+select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested
+-- !query 10 schema
+struct<v:array<array<int>>>
+-- !query 10 output
+[[96,65],[]]
+[[99],[123],[]]
+[[]]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 923482024b033..1d5707a2c7047 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1800,6 +1800,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
   }
 
+  test("filter function - array for primitive type not containing null") {
+    val df = Seq(
+      Seq(1, 9, 8, 7),
+      Seq(5, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
+        Seq(
+          Row(Seq(8)),
+          Row(Seq(8, 2)),
+          Row(Seq.empty),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeNotContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeNotContainsNull()
+  }
+
+  test("filter function - array for primitive type containing null") {
+    val df = Seq[Seq[Integer]](
+      Seq(1, 9, 8, null, 7),
+      Seq(5, null, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
+        Seq(
+          Row(Seq(8)),
+          Row(Seq(8, 2)),
+          Row(Seq.empty),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeContainsNull()
+  }
+
+  test("filter function - array for non-primitive type") {
+    val df = Seq(
+      Seq("c", "a", "b"),
+      Seq("b", null, "c", null),
+      Seq.empty,
+      null
+    ).toDF("s")
+
+    def testNonPrimitiveType(): Unit = {
+      checkAnswer(df.selectExpr("filter(s, x -> x is not null)"),
+        Seq(
+          Row(Seq("c", "a", "b")),
+          Row(Seq("b", "c")),
+          Row(Seq.empty),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testNonPrimitiveType()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testNonPrimitiveType()
+  }
+
+  test("filter function - invalid") {
+    val df = Seq(
+      (Seq("c", "a", "b"), 1),
+      (Seq("b", null, "c", null), 2),
+      (Seq.empty, 3),
+      (null, 4)
+    ).toDF("s", "i")
+
+    val ex1 = intercept[AnalysisException] {
+      df.selectExpr("filter(s, (x, y) -> x + y)")
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      df.selectExpr("filter(i, x -> x)")
+    }
+    assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("filter(s, x -> x)")
+    }
+    assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
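As a usage sketch for the function added above: a minimal, self-contained example of exercising `filter` through the DataFrame API. The session setup, the object name, and the expected rows are illustrative assumptions, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

object FilterExample {
  def main(args: Array[String]): Unit = {
    // Assumes a local Spark build that includes this patch.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("filter-example")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(Seq(1, 2, 3, 4), Seq(5, 6)).toDF("xs")
    // Keep only even elements; a null input array yields null, matching ArrayFilter.eval.
    df.selectExpr("filter(xs, x -> x % 2 == 0) AS evens").show()
    // Expected rows: [2, 4] and [6]

    spark.stop()
  }
}
```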
From e9af0947044312fdd8da74498f9935784cdc44f2 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 3 Aug 2018 17:11:39 +0900
Subject: [PATCH 2/5] Refactor.

---
 .../expressions/higherOrderFunctions.scala | 32 ++++++++++---------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index ca3a274d7ad1d..8f632eb373c29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -142,6 +142,18 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
   @transient lazy val functionForEval: Expression = functionsForEval.head
 }
 
+object ArrayBasedHigherOrderFunction {
+
+  def elementArgumentType(dt: DataType): (DataType, Boolean) = {
+    dt match {
+      case ArrayType(elementType, containsNull) => (elementType, containsNull)
+      case _ =>
+        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
+        (elementType, containsNull)
+    }
+  }
+}
+
 /**
  * Transform elements in an array using the transform function. This is similar to
  * a `map` in functional programming.
@@ -166,17 +178,12 @@ case class ArrayTransform(
   override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
-    val (elementType, containsNull) = input.dataType match {
-      case ArrayType(elementType, containsNull) => (elementType, containsNull)
-      case _ =>
-        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
-        (elementType, containsNull)
-    }
+    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
     function match {
       case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
-        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
+        copy(function = f(function, elem :: (IntegerType, false) :: Nil))
       case _ =>
-        copy(function = f(function, (elementType, containsNull) :: Nil))
+        copy(function = f(function, elem :: Nil))
     }
   }
 
@@ -236,13 +243,8 @@ case class ArrayFilter(
   override def expectingFunctionType: AbstractDataType = BooleanType
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
-    val (elementType, containsNull) = input.dataType match {
-      case ArrayType(elementType, containsNull) => (elementType, containsNull)
-      case _ =>
-        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
-        (elementType, containsNull)
-    }
-    copy(function = f(function, (elementType, containsNull) :: Nil))
+    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
+    copy(function = f(function, elem :: Nil))
   }
 
   @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
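For reference, the refactor above only extracts the duplicated `(elementType, containsNull)` destructuring into `ArrayBasedHigherOrderFunction.elementArgumentType`. A standalone sketch of the same pattern, using a simplified stand-in for Spark's `DataType` hierarchy (`IntType` and the default value below are illustrative, not Spark's actual definitions):

```scala
sealed trait DataType
case object IntType extends DataType
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType

object HigherOrderFunctionHelper {
  // Stand-in for ArrayType.defaultConcreteType: the fallback used while the
  // input expression's type is not yet resolved to a concrete array type.
  val defaultConcreteType: ArrayType = ArrayType(IntType, containsNull = true)

  // Returns the lambda argument's (type, nullability) for a given input type,
  // centralizing the logic that both bind implementations previously duplicated.
  def elementArgumentType(dt: DataType): (DataType, Boolean) = dt match {
    case ArrayType(elementType, containsNull) => (elementType, containsNull)
    case _ =>
      val ArrayType(elementType, containsNull) = defaultConcreteType
      (elementType, containsNull)
  }
}
```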
From e79ebbdb71d851b7f95bbb5eb7ce02b6c6edb3eb Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 3 Aug 2018 18:06:46 +0900
Subject: [PATCH 3/5] Address comments.

---
 .../catalyst/analysis/FunctionRegistry.scala |  2 +-
 .../expressions/higherOrderFunctions.scala   |  4 ++--
 .../inputs/higher-order-functions.sql        |  6 +++---
 .../results/higher-order-functions.sql.out   |  6 +++---
 .../spark/sql/DataFrameFunctionsSuite.scala  | 20 +++++++++----------
 5 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d0efe975f81ce..d40db5c4e708a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -441,7 +441,7 @@ object FunctionRegistry {
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
-    expression[ArrayFilter]("filter"),
+    expression[ArrayFilter]("array_filter"),
     CreateStruct.registryEntry,
 
     // misc functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 8f632eb373c29..3ceb531a8c9a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -255,7 +255,7 @@ case class ArrayFilter(
       null
     } else {
       val f = functionForEval
-      val buffer = new mutable.ArrayBuffer[Any]
+      val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
       var i = 0
       while (i < arr.numElements) {
         elementVar.value.set(arr.get(i, elementVar.dataType))
@@ -268,5 +268,5 @@ case class ArrayFilter(
     }
   }
 
-  override def prettyName: String = "filter"
+  override def prettyName: String = "array_filter"
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index f833aa5818bc1..a6b19e100906c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -26,10 +26,10 @@ select transform(ys, 0) as v from nested;
 select transform(cast(null as array<int>), x -> x + 1) as v;
 
 -- Filter.
-select filter(ys, y -> y > 30) as v from nested;
+select array_filter(ys, y -> y > 30) as v from nested;
 
 -- Filter a null array
-select filter(cast(null as array<int>), y -> true) as v;
+select array_filter(cast(null as array<int>), y -> true) as v;
 
 -- Filter nested arrays
-select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;
+select transform(zs, z -> array_filter(z, zz -> zz > 50)) as v from nested;
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index 4c5d972378b31..d67075a482441 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -82,7 +82,7 @@ NULL
 
 
 -- !query 8
-select filter(ys, y -> y > 30) as v from nested
+select array_filter(ys, y -> y > 30) as v from nested
 -- !query 8 schema
 struct<v:array<int>>
 -- !query 8 output
@@ -92,7 +92,7 @@ struct<v:array<int>>
 
 
 -- !query 9
-select filter(cast(null as array<int>), y -> true) as v
+select array_filter(cast(null as array<int>), y -> true) as v
 -- !query 9 schema
 struct<v:array<int>>
 -- !query 9 output
@@ -100,7 +100,7 @@ NULL
 
 
 -- !query 10
-select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested
+select transform(zs, z -> array_filter(z, zz -> zz > 50)) as v from nested
 -- !query 10 schema
 struct<v:array<array<int>>>
 -- !query 10 output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1d5707a2c7047..526474b64b6ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1800,7 +1800,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
   }
 
-  test("filter function - array for primitive type not containing null") {
+  test("array_filter function - array for primitive type not containing null") {
     val df = Seq(
       Seq(1, 9, 8, 7),
       Seq(5, 8, 9, 7, 2),
@@ -1809,7 +1809,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("i")
 
     def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
-      checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
+      checkAnswer(df.selectExpr("array_filter(i, x -> x % 2 == 0)"),
         Seq(
           Row(Seq(8)),
           Row(Seq(8, 2)),
@@ -1824,7 +1824,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testArrayOfPrimitiveTypeNotContainsNull()
   }
 
-  test("filter function - array for primitive type containing null") {
+  test("array_filter function - array for primitive type containing null") {
     val df = Seq[Seq[Integer]](
       Seq(1, 9, 8, null, 7),
       Seq(5, null, 8, 9, 7, 2),
@@ -1833,7 +1833,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("i")
 
     def testArrayOfPrimitiveTypeContainsNull(): Unit = {
-      checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
+      checkAnswer(df.selectExpr("array_filter(i, x -> x % 2 == 0)"),
         Seq(
           Row(Seq(8)),
           Row(Seq(8, 2)),
@@ -1848,7 +1848,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testArrayOfPrimitiveTypeContainsNull()
   }
 
-  test("filter function - array for non-primitive type") {
+  test("array_filter function - array for non-primitive type") {
     val df = Seq(
       Seq("c", "a", "b"),
       Seq("b", null, "c", null),
@@ -1857,7 +1857,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("s")
 
     def testNonPrimitiveType(): Unit = {
-      checkAnswer(df.selectExpr("filter(s, x -> x is not null)"),
+      checkAnswer(df.selectExpr("array_filter(s, x -> x is not null)"),
         Seq(
           Row(Seq("c", "a", "b")),
           Row(Seq("b", "c")),
@@ -1872,7 +1872,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testNonPrimitiveType()
   }
 
-  test("filter function - invalid") {
+  test("array_filter function - invalid") {
     val df = Seq(
       (Seq("c", "a", "b"), 1),
       (Seq("b", null, "c", null), 2),
@@ -1881,17 +1881,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("s", "i")
 
     val ex1 = intercept[AnalysisException] {
-      df.selectExpr("filter(s, (x, y) -> x + y)")
+      df.selectExpr("array_filter(s, (x, y) -> x + y)")
     }
     assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
 
     val ex2 = intercept[AnalysisException] {
-      df.selectExpr("filter(i, x -> x)")
+      df.selectExpr("array_filter(i, x -> x)")
     }
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
 
     val ex3 = intercept[AnalysisException] {
-      df.selectExpr("filter(s, x -> x)")
+      df.selectExpr("array_filter(s, x -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
   }
From 7f4351f1528164b9edf312debba0876a37e43cec Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 3 Aug 2018 18:38:00 +0900
Subject: [PATCH 4/5] Revert "Address comments."

This reverts commit e79ebbdb71d851b7f95bbb5eb7ce02b6c6edb3eb.
---
 .../catalyst/analysis/FunctionRegistry.scala |  2 +-
 .../expressions/higherOrderFunctions.scala   |  4 ++--
 .../inputs/higher-order-functions.sql        |  6 +++---
 .../results/higher-order-functions.sql.out   |  6 +++---
 .../spark/sql/DataFrameFunctionsSuite.scala  | 20 +++++++++----------
 5 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d40db5c4e708a..d0efe975f81ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -441,7 +441,7 @@ object FunctionRegistry {
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
-    expression[ArrayFilter]("array_filter"),
+    expression[ArrayFilter]("filter"),
     CreateStruct.registryEntry,
 
     // misc functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 3ceb531a8c9a0..8f632eb373c29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -255,7 +255,7 @@ case class ArrayFilter(
       null
     } else {
       val f = functionForEval
-      val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
+      val buffer = new mutable.ArrayBuffer[Any]
       var i = 0
       while (i < arr.numElements) {
         elementVar.value.set(arr.get(i, elementVar.dataType))
@@ -268,5 +268,5 @@ case class ArrayFilter(
     }
   }
 
-  override def prettyName: String = "array_filter"
+  override def prettyName: String = "filter"
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index a6b19e100906c..f833aa5818bc1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -26,10 +26,10 @@ select transform(ys, 0) as v from nested;
 select transform(cast(null as array<int>), x -> x + 1) as v;
 
 -- Filter.
-select array_filter(ys, y -> y > 30) as v from nested;
+select filter(ys, y -> y > 30) as v from nested;
 
 -- Filter a null array
-select array_filter(cast(null as array<int>), y -> true) as v;
+select filter(cast(null as array<int>), y -> true) as v;
 
 -- Filter nested arrays
-select transform(zs, z -> array_filter(z, zz -> zz > 50)) as v from nested;
+select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index d67075a482441..4c5d972378b31 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -82,7 +82,7 @@ NULL
 
 
 -- !query 8
-select array_filter(ys, y -> y > 30) as v from nested
+select filter(ys, y -> y > 30) as v from nested
 -- !query 8 schema
 struct<v:array<int>>
 -- !query 8 output
@@ -92,7 +92,7 @@ struct<v:array<int>>
 
 
 -- !query 9
-select array_filter(cast(null as array<int>), y -> true) as v
+select filter(cast(null as array<int>), y -> true) as v
 -- !query 9 schema
 struct<v:array<int>>
 -- !query 9 output
@@ -100,7 +100,7 @@ NULL
 
 
 -- !query 10
-select transform(zs, z -> array_filter(z, zz -> zz > 50)) as v from nested
+select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested
 -- !query 10 schema
 struct<v:array<array<int>>>
 -- !query 10 output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 526474b64b6ff..1d5707a2c7047 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1800,7 +1800,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
   }
 
-  test("array_filter function - array for primitive type not containing null") {
+  test("filter function - array for primitive type not containing null") {
     val df = Seq(
       Seq(1, 9, 8, 7),
       Seq(5, 8, 9, 7, 2),
@@ -1809,7 +1809,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("i")
 
     def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
-      checkAnswer(df.selectExpr("array_filter(i, x -> x % 2 == 0)"),
+      checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
        Seq(
           Row(Seq(8)),
           Row(Seq(8, 2)),
@@ -1824,7 +1824,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testArrayOfPrimitiveTypeNotContainsNull()
   }
 
-  test("array_filter function - array for primitive type containing null") {
+  test("filter function - array for primitive type containing null") {
     val df = Seq[Seq[Integer]](
       Seq(1, 9, 8, null, 7),
       Seq(5, null, 8, 9, 7, 2),
@@ -1833,7 +1833,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("i")
 
     def testArrayOfPrimitiveTypeContainsNull(): Unit = {
-      checkAnswer(df.selectExpr("array_filter(i, x -> x % 2 == 0)"),
+      checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
         Seq(
           Row(Seq(8)),
           Row(Seq(8, 2)),
@@ -1848,7 +1848,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testArrayOfPrimitiveTypeContainsNull()
   }
 
-  test("array_filter function - array for non-primitive type") {
+  test("filter function - array for non-primitive type") {
     val df = Seq(
       Seq("c", "a", "b"),
       Seq("b", null, "c", null),
@@ -1857,7 +1857,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("s")
 
     def testNonPrimitiveType(): Unit = {
-      checkAnswer(df.selectExpr("array_filter(s, x -> x is not null)"),
+      checkAnswer(df.selectExpr("filter(s, x -> x is not null)"),
         Seq(
           Row(Seq("c", "a", "b")),
           Row(Seq("b", "c")),
@@ -1872,7 +1872,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testNonPrimitiveType()
   }
 
-  test("array_filter function - invalid") {
+  test("filter function - invalid") {
     val df = Seq(
       (Seq("c", "a", "b"), 1),
       (Seq("b", null, "c", null), 2),
@@ -1881,17 +1881,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).toDF("s", "i")
 
     val ex1 = intercept[AnalysisException] {
-      df.selectExpr("array_filter(s, (x, y) -> x + y)")
+      df.selectExpr("filter(s, (x, y) -> x + y)")
     }
     assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
 
     val ex2 = intercept[AnalysisException] {
-      df.selectExpr("array_filter(i, x -> x)")
+      df.selectExpr("filter(i, x -> x)")
     }
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
 
     val ex3 = intercept[AnalysisException] {
-      df.selectExpr("array_filter(s, x -> x)")
+      df.selectExpr("filter(s, x -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
   }
From f6aaa902b687c1bddf233c65c3739ec4ac407774 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 3 Aug 2018 18:41:22 +0900
Subject: [PATCH 5/5] Use a size hint.

---
 .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 8f632eb373c29..e15225ffbd2d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -255,7 +255,7 @@ case class ArrayFilter(
       null
     } else {
       val f = functionForEval
-      val buffer = new mutable.ArrayBuffer[Any]
+      val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
       var i = 0
       while (i < arr.numElements) {
         elementVar.value.set(arr.get(i, elementVar.dataType))