From 9cf1ebfad9e9ab5af543027e245f8a70b90ab87b Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Wed, 27 Mar 2019 23:34:48 -0400 Subject: [PATCH 01/38] Add higher order functions to Scala API --- .../expressions/higherOrderFunctions.scala | 133 +++++++ .../HigherOrderFunctionsSuite.scala | 125 +------ .../org/apache/spark/sql/functions.scala | 135 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 350 +++++++++++++++++- 4 files changed, 617 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index e6cc11d1ad280..39e5b63a1c218 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -28,6 +28,139 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods +/** + * Helper methods for constructing higher order functions. + */ +object HigherOrderUtils { + def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + + def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size == argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType == dataType) + assert(arg.nullable == nullable) + } + f + } + + // Array-based helpers + def filter(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def exists(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val 
ArrayType(et, cn) = expr.dataType + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, et, cn, merge), + createLambda(zeroType, true, finish)) + .bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) + } + + // Map-based helpers + + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) + } +} + /** * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 03fb75e330c66..1362572b0d2a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -24,102 +24,7 @@ import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import org.apache.spark.sql.catalyst.dsl.expressions._ - - private def createLambda( - dt: DataType, - nullable: Boolean, - f: Expression => Expression): Expression = { - val lv = NamedLambdaVariable("arg", dt, nullable) - val function = f(lv) - LambdaFunction(function, Seq(lv)) - } - - private def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - f: (Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val function = f(lv1, lv2) - LambdaFunction(function, Seq(lv1, lv2)) - } - - private def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - dt3: DataType, - nullable3: Boolean, - f: (Expression, Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) - val function = f(lv1, lv2, lv3) - LambdaFunction(function, Seq(lv1, lv2, lv3)) - } - - private def validateBinding( - e: Expression, - argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { - case f: LambdaFunction => - assert(f.arguments.size === argInfo.size) - f.arguments.zip(argInfo).foreach { - case (arg, (dataType, nullable)) => - assert(arg.dataType === dataType) - assert(arg.nullable === nullable) - } - f - } - - def transform(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) - } - - def filter(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression, - finish: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - val zeroType = zero.dataType - ArrayAggregate( - expr, - zero, - createLambda(zeroType, true, et, cn, merge), - createLambda(zeroType, true, finish)) - .bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression): Expression = { - aggregate(expr, zero, merge, identity) - } - - def transformValues(expr: Expression, f: (Expression, Expression) => Expression): 
Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } + import org.apache.spark.sql.catalyst.expressions.HigherOrderUtils._ test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) @@ -163,10 +68,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapFilter") { - def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), @@ -244,11 +145,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ArrayExists") { - def exists(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) - } - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) @@ -457,16 +353,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapZipWith") { - def map_zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression, Expression) => Expression): Expression = { - val MapType(kt, vt1, _) = left.dataType - val MapType(_, vt2, _) = right.dataType - MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) - .bind(validateBinding) - } - val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4), @@ -549,15 +435,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ZipWith") { - def zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression) => Expression): Expression = { - val ArrayType(leftT, _) = left.dataType - val ArrayType(rightT, _) = right.dataType - ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) - } - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f99186cabc26d..e4ff7e460b38d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3316,6 +3316,141 @@ object functions { ArrayExcept(col1.expr, col2.expr) } + private def expressionFunction(f: Column => Column) + : Expression => Expression = + x => f(Column(x)).expr + + private def expressionFunction(f: (Column, Column) => Column) + : (Expression, Expression) => Expression = + (x, y) => f(Column(x), Column(y)).expr + + private def expressionFunction(f: (Column, Column, Column) => Column) + : (Expression, Expression, Expression) => Expression = + (x, y, z) => 
f(Column(x), Column(y), Column(z)).expr + + /** + * Returns an array of elements after applying a transformation to each element + * in the input array. + * + * @group collection_funcs + */ + def transform(column: Column, f: Column => Column): Column = withExpr { + HigherOrderUtils.transform(column.expr, expressionFunction(f)) + } + + /** + * Returns an array of elements after applying a transformation to each element + * in the input array. + * + * @group collection_funcs + */ + def transform(column: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.transform(column.expr, expressionFunction(f)) + } + + /** + * Returns whether a predicate holds for one or more elements in the array. + * + * @group collection_funcs + */ + def exists(column: Column, f: Column => Column): Column = withExpr { + HigherOrderUtils.exists(column.expr, expressionFunction(f)) + } + + /** + * Returns an array of elements for which a predicate holds in a given array. + * + * @group collection_funcs + */ + def filter(column: Column, f: Column => Column): Column = withExpr { + HigherOrderUtils.filter(column.expr, expressionFunction(f)) + } + + /** + * Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. The final state is converted into the final result + * by applying a finish function. + * + * @group collection_funcs + */ + def aggregate( + expr: Column, + zero: Column, + merge: (Column, Column) => Column, + finish: Column => Column): Column = withExpr { + HigherOrderUtils.aggregate( + expr.expr, + zero.expr, + expressionFunction(merge), + expressionFunction(finish) + ) + } + + /** + * Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. + * + * @group collection_funcs + */ + def aggregate( + expr: Column, + zero: Column, + merge: (Column, Column) => Column): Column = + aggregate(expr, zero, merge, identity) + + /** + * Merge two given arrays, element-wise, into a single array using a function. + * If one array is shorter, nulls are appended at the end to match the length of the longer + * array, before applying the function. + * + * @group collection_funcs + */ + def zip_with( + left: Column, + right: Column, + f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + } + + /** + * Applies a function to every key-value pair in a map and returns + * a map with the results of those applications as the new keys for the pairs. + * + * @group collection_funcs + */ + def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + } + + /** + * Applies a function to every key-value pair in a map and returns + * a map with the results of those applications as the new values for the pairs. + * + * @group collection_funcs + */ + def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + } + + /** + * Returns a map whose key-value pairs satisfy a predicate. + * + * @group collection_funcs + */ + def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + } + + /** + * Merge two given maps, key-wise into a single map using a function. 
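+ * For example, a usage sketch (the DataFrame and its map columns are hypothetical, not part of this patch): + * {{{ + *   df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => v1 + v2)) + * }}}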
+ * + * @group collection_funcs + */ + def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = + withExpr { + HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e5c2de950c2c0..6a9ab0413c455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1930,6 +1930,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, 9, 11, 10, 6)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("i"), x => x + 1)), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -1960,6 +1972,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, null, 10, 12, 11, 7)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("i"), x => x + 1)), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -1990,6 +2014,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b0", null, "c2", null)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("s"), x => concat(x, x))), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("s"), (x, i) => concat(x, i))), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2034,6 +2070,32 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq("b", null, "c", null, null))), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("arg"), arg => arg)), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("arg"), _ => df("arg"))), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("arg"), x => concat(df("arg"), array(x)))), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2080,6 +2142,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), 
Row(Map(1 -> 10), Map(3 -> -3)))) + checkAnswer(dfInts.select( + map_filter(dfInts("m"), (k, v) => k * 10 === v), + map_filter(dfInts("m"), (k, v) => k === (v * -1))), + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) + val dfComplex = Seq( Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") @@ -2090,6 +2160,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), Row(Map(), Map(2 -> Seq(-2, -2))))) + checkAnswer(dfComplex.select( + map_filter(dfComplex("m"), (k, v) => k === element_at(v, 1)), + map_filter(dfComplex("m"), (k, v) => k === size(v))), + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) + // Invalid use cases val df = Seq( (Map(1 -> "a"), 1), @@ -2112,6 +2189,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + val ex3a = intercept[MatchError] { + df.select(map_filter(df("i"), (k, v) => k > v)) + } + assert(ex3a.getMessage.contains("IntegerType")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_filter(a, (k, v) -> k > v)") } @@ -2133,6 +2215,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2157,6 +2245,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2181,6 +2275,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b", "c")), Row(Seq.empty), Row(null))) + checkAnswer(df.select(filter(df("s"), x => x.isNotNull)), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2208,11 +2308,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[MatchError] { + df.select(filter(df("i"), x => x)) + } + assert(ex2a.getMessage.contains("IntegerType")) + val ex3 = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(filter(df("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("filter(a, x -> x)") } @@ -2234,6 +2344,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(false), Row(false), Row(null))) + checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated 
without codegen @@ -2258,6 +2374,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(false), Row(false), Row(null))) + checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2282,6 +2404,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(false), Row(null))) + checkAnswer(df.select(exists(df("s"), x => x.isNull)), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2309,11 +2437,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[MatchError] { + df.select(exists(df("i"), x => x)) + } + assert(ex2a.getMessage.contains("IntegerType")) + val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(exists(df("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("exists(a, x -> x)") } @@ -2341,6 +2479,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(310), Row(0), Row(null))) + checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x, _ * 10)), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2372,6 +2522,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(0), Row(0), Row(null))) + checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.select( + aggregate(df("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2403,6 +2567,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(""), Row("c"), Row(null))) + checkAnswer(df.select(aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x))), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.select( + aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x), acc => coalesce(acc, lit("")))), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2435,11 +2613,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3a = intercept[MatchError] { + df.select(aggregate(df("i"), lit(0), (acc, x) => x)) + } + assert(ex3a.getMessage.contains("IntegerType")) + val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + val ex4a = intercept[AnalysisException] { + df.select(aggregate(df("s"), lit(0), (acc, x) => x)) + } + assert(ex4a.getMessage.contains("data type 
mismatch: argument 3 requires int type")) + val ex5 = intercept[AnalysisException] { df.selectExpr("aggregate(a, 0, (acc, x) -> x)") } @@ -2460,6 +2648,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(10 -> null, 8 -> false, 4 -> null)), Row(Map(5 -> null)), Row(null))) + + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) } test("map_zip_with function - map of non-primitive types") { @@ -2476,6 +2671,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), Row(Map("a" -> Row("d", null))), Row(null))) + + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => struct(v1, v2))), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) } test("map_zip_with function - invalid") { @@ -2494,16 +2696,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) + val ex2a = intercept[NoSuchElementException] { + df.select(map_zip_with(df("mis"), df("mmi"), (x, y, z) => concat(x, y, z))) + } + assert(ex2a.getMessage.contains("None.get")) + val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") } assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + val ex3a = intercept[MatchError] { + df.select(map_zip_with(df("i"), df("mis"), (x, y, z) => concat(x, y, z))) + } + assert(ex3a.getMessage.contains("IntegerType")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + val ex4a = intercept[MatchError] { + df.select(map_zip_with(df("mis"), df("i"), (x, y, z) => concat(x, y, z))) + } + assert(ex4a.getMessage.contains("IntegerType")) + val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } @@ -2532,27 +2749,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + checkAnswer(dfExample1.select(transform_keys(dfExample1("i"), (k, v) => k + v)), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.select( + transform_keys( + dfExample2("j"), + (k, v) => element_at( + map_from_arrays( + array(lit(1), lit(2), lit(3)), + array(lit("one"), lit("two"), lit("three")) + ), + k + ) + ) + ), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), + (k, v) => (v * 2).cast("bigint") + k)), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + 
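+ // Each selectExpr-based assertion above is paired with a Column-based one, checking that the SQL lambda syntax, e.g. `(k, v) -> k + v`, and the Scala lambda syntax, `(k, v) => k + v`, produce the same result, as the next pair below shows.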
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), (k, v) => k + v)), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), Seq(Row(Map(true -> true, true -> false)))) + checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), (k, v) => k % 2 === 0 || v)), + Seq(Row(Map(true -> true, true -> false)))) + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), + (k, v) => when(v, k * 2).otherwise(k * 3))), Seq(Row(Map(50 -> true, 78 -> false)))) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) + + checkAnswer(dfExample4.select(transform_keys(dfExample4("y"), + (k, v) => array_contains(k, lit(3)) && v)), + Seq(Row(Map(false -> false)))) } // Test with local relation, the Project will be evaluated without codegen @@ -2590,6 +2839,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("Cannot use null as map key")) + val ex3a = intercept[Exception] { + dfExample1.select(transform_keys(dfExample1("i"), (k, v) => v)).show() + } + assert(ex3a.getMessage.contains("Cannot use null as map key")) + val ex4 = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") } @@ -2654,6 +2908,46 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), Seq(Row(Map(1 -> 3)))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k + v)), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.select( + transform_values(dfExample2("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.select(transform_values(dfExample2("x"), + (k, v) => (!k) && v === "abc")), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.select(transform_values(dfExample3("y"), (k, v) => v * v)), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.select( + transform_values(dfExample3("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.select(transform_values(dfExample3("y"), (k, v) => concat(k, v.cast("string")))), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + val testMap = map_from_arrays( + array(lit(1), lit(2), lit(3)), + array(lit("one"), lit("two"), lit("three")) + ) + + checkAnswer( + dfExample4.select(transform_values(dfExample4("z"), + (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.select(transform_values(dfExample4("z"), (k, v) => k - v)), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.select(transform_values(dfExample5("c"), (k, v) => k + size(v))), + Seq(Row(Map(1 -> 3)))) } // Test with local relation, the Project will be evaluated without codegen @@ -2697,6 +2991,28 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), Seq(Row(Map.empty[BigInt, BigInt]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), + (k, v) => lit(null).cast("int"))), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k)), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v)), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(0))), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit("value"))), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(true))), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v.cast("bigint"))), + Seq(Row(Map.empty[BigInt, BigInt]))) } testEmpty() @@ -2721,6 +3037,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr( "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("a"), + (k, v) => lit(null).cast("int"))), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.select( + transform_values(dfExample2("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) + ), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) } testNullValue() @@ -2759,6 +3084,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains( "data type mismatch: argument 1 requires map type")) + + val ex3a = intercept[MatchError] { + dfExample3.select(transform_values(dfExample3("x"), (k, v) => k + 1)) + } + assert(ex3a.getMessage.contains("IntegerType")) } testInvalidLambdaFunctions() @@ -2785,10 +3115,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq.empty), Row(null)) checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) + checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)), expectedValue1) val expectedValue2 = Seq( Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) + checkAnswer( + df2.select(zip_with(df2("val1"), df2("val2"), (x, y) => struct(y, x))), + expectedValue2 + ) } test("arrays zip_with function - for non-primitive types") { @@ -2803,7 +3138,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row("x", "a"), Row("y", null))), Row(Seq.empty), Row(null)) - checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1) + checkAnswer( + df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), + expectedValue1 + ) + checkAnswer( + df.select(zip_with(df("val1"), df("val2"), (x, y) => struct(y, x))), + expectedValue1 + ) } test("arrays zip_with function - invalid") { @@ -2825,6 +3167,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("zip_with(i, a2, (acc, x) -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3a = 
intercept[MatchError] { + df.select(zip_with(df("i"), df("a2"), (acc, x) => x)) + } + assert(ex3a.getMessage.contains("IntegerType")) val ex4 = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a, (acc, x) -> x)") } From efc6ba42ecac7afa8986449cbaf884bd73107240 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 28 Mar 2019 08:59:01 -0400 Subject: [PATCH 02/38] Add (Scala-specific) note to higher order functions These signatures won't work in Java as they rely on Scala lambdas --- .../org/apache/spark/sql/functions.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e4ff7e460b38d..fd1fe1abe418a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3329,7 +3329,7 @@ object functions { (x, y, z) => f(Column(x), Column(y), Column(z)).expr /** - * Returns an array of elements after applying a transformation to each element + * (Scala-specific) Returns an array of elements after applying a transformation to each element * in the input array. * * @group collection_funcs */ @@ -3339,7 +3339,7 @@ } /** - * Returns an array of elements after applying a transformation to each element + * (Scala-specific) Returns an array of elements after applying a transformation to each element * in the input array. * * @group collection_funcs */ @@ -3349,7 +3349,7 @@ } /** - * Returns whether a predicate holds for one or more elements in the array. + * (Scala-specific) Returns whether a predicate holds for one or more elements in the array. * * @group collection_funcs */ @@ -3358,7 +3358,7 @@ } /** - * Returns an array of elements for which a predicate holds in a given array. + * (Scala-specific) Returns an array of elements for which a predicate holds in a given array. * * @group collection_funcs */ @@ -3367,7 +3367,7 @@ } /** - * Applies a binary operator to an initial state and all elements in the array, + * (Scala-specific) Applies a binary operator to an initial state and all elements in the array, * and reduces this to a single state. The final state is converted into the final result * by applying a finish function. * * @group collection_funcs */ @@ -3387,7 +3387,7 @@ } /** - * Applies a binary operator to an initial state and all elements in the array, + * (Scala-specific) Applies a binary operator to an initial state and all elements in the array, * and reduces this to a single state. * * @group collection_funcs */ @@ -3399,7 +3399,7 @@ aggregate(expr, zero, merge, identity) /** - * Merge two given arrays, element-wise, into a single array using a function. + * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function. * If one array is shorter, nulls are appended at the end to match the length of the longer * array, before applying the function. * * @group collection_funcs */ @@ -3413,7 +3413,7 @@ } /** - * Applies a function to every key-value pair in a map and returns + * (Scala-specific) Applies a function to every key-value pair in a map and returns * a map with the results of those applications as the new keys for the pairs. 
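 * For example, a usage sketch (the DataFrame and its map column are hypothetical): * {{{ *   df.select(transform_keys(df("m"), (k, v) => k + v)) * }}}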
 * * @group collection_funcs @@ -3423,7 +3423,7 @@ } /** - * Applies a function to every key-value pair in a map and returns + * (Scala-specific) Applies a function to every key-value pair in a map and returns * a map with the results of those applications as the new values for the pairs. * * @group collection_funcs */ @@ -3433,7 +3433,7 @@ } /** - * Returns a map whose key-value pairs satisfy a predicate. + * (Scala-specific) Returns a map whose key-value pairs satisfy a predicate. * * @group collection_funcs */ @@ -3442,7 +3442,7 @@ } /** - * Merge two given maps, key-wise into a single map using a function. + * (Scala-specific) Merge two given maps, key-wise into a single map using a function. * * @group collection_funcs */ From b9dceec0587baa0d7939dd98972a5c54bf7c7ee4 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 28 Mar 2019 09:07:06 -0400 Subject: [PATCH 03/38] Follow style guide more closely --- .../org/apache/spark/sql/functions.scala | 44 +++++++------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fd1fe1abe418a..84380b9ca745a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3316,16 +3316,14 @@ object functions { ArrayExcept(col1.expr, col2.expr) } - private def expressionFunction(f: Column => Column) - : Expression => Expression = + private def expressionFunction(f: Column => Column): Expression => Expression = x => f(Column(x)).expr - private def expressionFunction(f: (Column, Column) => Column) - : (Expression, Expression) => Expression = + private def expressionFunction(f: (Column, Column) => Column): (Expression, Expression) => Expression = (x, y) => f(Column(x), Column(y)).expr private def expressionFunction(f: (Column, Column, Column) => Column) : (Expression, Expression, Expression) => Expression = (x, y, z) => f(Column(x), Column(y), Column(z)).expr /** @@ -3373,18 +3371,15 @@ object functions { * * @group collection_funcs */ - def aggregate( - expr: Column, - zero: Column, - merge: (Column, Column) => Column, - finish: Column => Column): Column = withExpr { - HigherOrderUtils.aggregate( - expr.expr, - zero.expr, - expressionFunction(merge), - expressionFunction(finish) - ) - } + def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, finish: Column => Column): Column = + withExpr { + HigherOrderUtils.aggregate( + expr.expr, + zero.expr, + expressionFunction(merge), + expressionFunction(finish) + ) + } /** * (Scala-specific) Applies a binary operator to an initial state and all elements in the array, * and reduces this to a single state. * * @group collection_funcs */ - def aggregate( - expr: Column, - zero: Column, - merge: (Column, Column) => Column): Column = - aggregate(expr, zero, merge, identity) + def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column = + aggregate(expr, zero, merge, identity) /** * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function. 
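 * A usage sketch (the DataFrame and its array columns are hypothetical): * {{{ *   df.select(zip_with(df("xs"), df("ys"), (x, y) => x + y)) * }}}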
@@ -3405,10 +3397,7 @@ object functions { * * @group collection_funcs */ - def zip_with( - left: Column, - right: Column, - f: (Column, Column) => Column): Column = withExpr { + def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr { HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) } @@ -3446,8 +3435,7 @@ object functions { * * @group collection_funcs */ - def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = - withExpr { + def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = withExpr { HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) } From 1fb46a3e7f78f761dbc7c699242599275d903de2 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 28 Mar 2019 09:24:43 -0400 Subject: [PATCH 04/38] Fix scalastyle issues --- .../org/apache/spark/sql/functions.scala | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 84380b9ca745a..19299cdd85203 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3319,7 +3319,8 @@ object functions { private def expressionFunction(f: Column => Column): Expression => Expression = x => f(Column(x)).expr - private def expressionFunction(f: (Column, Column) => Column): (Expression, Expression) => Expression = + private def expressionFunction(f: (Column, Column) => Column) + : (Expression, Expression) => Expression = (x, y) => f(Column(x), Column(y)).expr private def expressionFunction(f: (Column, Column, Column) => Column) @@ -3371,15 +3372,15 @@ object functions { * * @group collection_funcs */ - def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, finish: Column => Column): Column = - withExpr { - HigherOrderUtils.aggregate( - expr.expr, - zero.expr, - expressionFunction(merge), - expressionFunction(finish) - ) - } + def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, + finish: Column => Column): Column = withExpr { + HigherOrderUtils.aggregate( + expr.expr, + zero.expr, + expressionFunction(merge), + expressionFunction(finish) + ) + } /** * (Scala-specific) Applies a binary operator to an initial state and all elements in the array, @@ -3435,9 +3436,10 @@ object functions { * * @group collection_funcs */ - def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = withExpr { - HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) - } + def map_zip_with(left: Column, right: Column, + f: (Column, Column, Column) => Column): Column = withExpr { + HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + } /** * Creates a new row for each element in the given array or map column. 
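 * A brief usage sketch (assumes a hypothetical DataFrame with an array column): * {{{ *   df.select(explode(df("xs"))) * }}}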
From 03d602ff09d2abd29762737fa9c76c832ac3a3ca Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 28 Mar 2019 19:01:20 -0400 Subject: [PATCH 05/38] Add Java-specific version of higher order function API --- .../org/apache/spark/sql/functions.scala | 129 +++++++++++++++++- 1 file changed, 128 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 19299cdd85203..d9b5b9e143c69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.Stable +import org.apache.spark.api.java.function._ import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -3327,6 +3328,17 @@ object functions { : (Expression, Expression, Expression) => Expression = (x, y, z) => f(Column(x), Column(y), Column(z)).expr + private def expressionFunction(f: Function[Column, Column]): Expression => Expression = + x => f.call(Column(x)).expr + + private def expressionFunction(f: Function2[Column, Column, Column]) + : (Expression, Expression) => Expression = + (x, y) => f.call(Column(x), Column(y)).expr + + private def expressionFunction(f: Function3[Column, Column, Column, Column]) + : (Expression, Expression, Expression) => Expression = + (x, y, z) => f.call(Column(x), Column(y), Column(z)).expr + /** * (Scala-specific) Returns an array of elements after applying a transformation to each element * in the input array. @@ -3389,7 +3401,7 @@ object functions { * @group collection_funcs */ def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column = - aggregate(expr, zero, merge, identity) + aggregate(expr, zero, merge, c => c) /** * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function. @@ -3441,6 +3453,121 @@ object functions { HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) } + /** + * (Java-specific) Returns an array of elements after applying a transformation to each element + * in the input array. + * + * @group collection_funcs + */ + def transform(column: Column, f: Function[Column, Column]): Column = withExpr { + HigherOrderUtils.transform(column.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Returns an array of elements after applying a transformation to each element + * in the input array. + * + * @group collection_funcs + */ + def transform(column: Column, f: Function2[Column, Column, Column]): Column = withExpr { + HigherOrderUtils.transform(column.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Returns whether a predicate holds for one or more elements in the array. + * + * @group collection_funcs + */ + def exists(column: Column, f: Function[Column, Column]): Column = withExpr { + HigherOrderUtils.exists(column.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Returns an array of elements for which a predicate holds in a given array. + * + * @group collection_funcs + */ + def filter(column: Column, f: Function[Column, Column]): Column = withExpr { + HigherOrderUtils.filter(column.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. 
The final state is converted into the final result + * by applying a finish function. + * + * @group collection_funcs + */ + def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column], + finish: Function[Column, Column]): Column = withExpr { + HigherOrderUtils.aggregate( + expr.expr, + zero.expr, + expressionFunction(merge), + expressionFunction(finish) + ) + } + + /** + * (Java-specific) Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. + * + * @group collection_funcs + */ + def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column]): Column = + aggregate(expr, zero, merge, new Function[Column, Column] { def call(c: Column): Column = c }) + + /** + * (Java-specific) Merge two given arrays, element-wise, into a single array using a function. + * If one array is shorter, nulls are appended at the end to match the length of the longer + * array, before applying the function. + * + * @group collection_funcs + */ + def zip_with(left: Column, right: Column, f: Function2[Column, Column, Column]): Column = + withExpr { + HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Applies a function to every key-value pair in a map and returns + * a map with the results of those applications as the new keys for the pairs. + * + * @group collection_funcs + */ + def transform_keys(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { + HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Applies a function to every key-value pair in a map and returns + * a map with the results of those applications as the new values for the pairs. + * + * @group collection_funcs + */ + def transform_values(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { + HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Returns a map whose key-value pairs satisfy a predicate. + * + * @group collection_funcs + */ + def map_filter(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { + HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + } + + /** + * (Java-specific) Merge two given maps, key-wise into a single map using a function. + * + * @group collection_funcs + */ + def map_zip_with(left: Column, right: Column, + f: Function3[Column, Column, Column, Column]): Column = withExpr { + HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + } + /** * Creates a new row for each element in the given array or map column. 
* From 6bf07d89ebbebf0073fe3feaa3c0b3d7aaaa45d9 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Fri, 14 Jun 2019 15:34:08 -0400 Subject: [PATCH 06/38] Do not prematurely bind lambda variables --- .../expressions/higherOrderFunctions.scala | 133 -------------- .../HigherOrderFunctionsSuite.scala | 125 ++++++++++++- .../org/apache/spark/sql/functions.scala | 100 ++++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 165 +++++++++--------- 4 files changed, 268 insertions(+), 255 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 39e5b63a1c218..e6cc11d1ad280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -28,139 +28,6 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods -/** - * Helper methods for constructing higher order functions. - */ -object HigherOrderUtils { - def createLambda( - dt: DataType, - nullable: Boolean, - f: Expression => Expression): Expression = { - val lv = NamedLambdaVariable("arg", dt, nullable) - val function = f(lv) - LambdaFunction(function, Seq(lv)) - } - - def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - f: (Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val function = f(lv1, lv2) - LambdaFunction(function, Seq(lv1, lv2)) - } - - def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - dt3: DataType, - nullable3: Boolean, - f: (Expression, Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) - val function = f(lv1, lv2, lv3) - LambdaFunction(function, Seq(lv1, lv2, lv3)) - } - - def validateBinding( - e: Expression, - argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { - case f: LambdaFunction => - assert(f.arguments.size == argInfo.size) - f.arguments.zip(argInfo).foreach { - case (arg, (dataType, nullable)) => - assert(arg.dataType == dataType) - assert(arg.nullable == nullable) - } - f - } - - // Array-based helpers - def filter(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def exists(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transform(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression, - finish: Expression => Expression): 
Expression = { - val ArrayType(et, cn) = expr.dataType - val zeroType = zero.dataType - ArrayAggregate( - expr, - zero, - createLambda(zeroType, true, et, cn, merge), - createLambda(zeroType, true, finish)) - .bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression): Expression = { - aggregate(expr, zero, merge, identity) - } - - def zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression) => Expression): Expression = { - val ArrayType(leftT, _) = left.dataType - val ArrayType(rightT, _) = right.dataType - ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) - } - - // Map-based helpers - - def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def map_zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression, Expression) => Expression): Expression = { - val MapType(kt, vt1, _) = left.dataType - val MapType(_, vt2, _) = right.dataType - MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) - .bind(validateBinding) - } -} - /** * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 1362572b0d2a7..03fb75e330c66 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -24,7 +24,102 @@ import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.expressions.HigherOrderUtils._ + + private def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + + private def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size === argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType === dataType) + assert(arg.nullable === nullable) + } + f + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + + def filter(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, et, cn, merge), + createLambda(zeroType, true, finish)) + .bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + + def 
transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) @@ -68,6 +163,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapFilter") { + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), @@ -145,6 +244,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ArrayExists") { + def exists(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) + } + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) @@ -353,6 +457,16 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) + } + val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4), @@ -435,6 +549,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ZipWith") { + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) + } + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d9b5b9e143c69..ae0d8e75b1569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3317,27 +3317,47 @@ object functions { ArrayExcept(col1.expr, col2.expr) } - private def expressionFunction(f: Column => Column): Expression => Expression = - x => f(Column(x)).expr + private def createLambda(f: Column => Column) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val function = f(Column(x)).expr + LambdaFunction(function, Seq(x)) + } - private def expressionFunction(f: (Column, Column) => Column) - : (Expression, Expression) => Expression = - (x, y) => 
f(Column(x), Column(y)).expr + private def createLambda(f: (Column, Column) => Column) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val y = UnresolvedNamedLambdaVariable(Seq("y")) + val function = f(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } - private def expressionFunction(f: (Column, Column, Column) => Column) - : (Expression, Expression, Expression) => Expression = - (x, y, z) => f(Column(x), Column(y), Column(z)).expr + private def createLambda(f: (Column, Column, Column) => Column) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val y = UnresolvedNamedLambdaVariable(Seq("y")) + val z = UnresolvedNamedLambdaVariable(Seq("z")) + val function = f(Column(x), Column(y), Column(z)).expr + LambdaFunction(function, Seq(x, y, z)) + } - private def expressionFunction(f: Function[Column, Column]): Expression => Expression = - x => f.call(Column(x)).expr + private def createLambda(f: Function[Column, Column]) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val function = f.call(Column(x)).expr + LambdaFunction(function, Seq(x)) + } - private def expressionFunction(f: Function2[Column, Column, Column]) - : (Expression, Expression) => Expression = - (x, y) => f.call(Column(x), Column(y)).expr + private def createLambda(f: Function2[Column, Column, Column]) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val y = UnresolvedNamedLambdaVariable(Seq("y")) + val function = f.call(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } - private def expressionFunction(f: Function3[Column, Column, Column, Column]) - : (Expression, Expression, Expression) => Expression = - (x, y, z) => f.call(Column(x), Column(y), Column(z)).expr + private def createLambda(f: Function3[Column, Column, Column, Column]) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val y = UnresolvedNamedLambdaVariable(Seq("y")) + val z = UnresolvedNamedLambdaVariable(Seq("z")) + val function = f.call(Column(x), Column(y), Column(z)).expr + LambdaFunction(function, Seq(x, y, z)) + } /** * (Scala-specific) Returns an array of elements after applying a transformation to each element * in the input array. * @@ -3346,7 +3366,7 @@ object functions { * @group collection_funcs */ def transform(column: Column, f: Column => Column): Column = withExpr { - HigherOrderUtils.transform(column.expr, expressionFunction(f)) + ArrayTransform(column.expr, createLambda(f)) } /** @@ -3356,7 +3376,7 @@ object functions { * @group collection_funcs */ def transform(column: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.transform(column.expr, expressionFunction(f)) + ArrayTransform(column.expr, createLambda(f)) } /** @@ -3365,7 +3385,7 @@ object functions { * @group collection_funcs */ def exists(column: Column, f: Column => Column): Column = withExpr { - HigherOrderUtils.exists(column.expr, expressionFunction(f)) + ArrayExists(column.expr, createLambda(f)) } /** @@ -3374,7 +3394,7 @@ object functions { * @group collection_funcs */ def filter(column: Column, f: Column => Column): Column = withExpr { - HigherOrderUtils.filter(column.expr, expressionFunction(f)) + ArrayFilter(column.expr, createLambda(f)) } /** @@ -3386,11 +3406,11 @@ */ def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, finish: Column => Column): Column = withExpr { - HigherOrderUtils.aggregate( + ArrayAggregate( expr.expr, zero.expr, - expressionFunction(merge), - expressionFunction(finish) + createLambda(merge), + createLambda(finish) ) } @@ -3411,7 +3431,7 @@
object functions { * @group collection_funcs */ def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + ZipWith(left.expr, right.expr, createLambda(f)) } /** @@ -3421,7 +3441,7 @@ object functions { * @group collection_funcs */ def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + TransformKeys(expr.expr, createLambda(f)) } /** @@ -3431,7 +3451,7 @@ object functions { * @group collection_funcs */ def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + TransformValues(expr.expr, createLambda(f)) } /** @@ -3440,7 +3460,7 @@ object functions { * @group collection_funcs */ def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + MapFilter(expr.expr, createLambda(f)) } /** @@ -3450,7 +3470,7 @@ object functions { */ def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = withExpr { - HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + MapZipWith(left.expr, right.expr, createLambda(f)) } /** @@ -3460,7 +3480,7 @@ object functions { * @group collection_funcs */ def transform(column: Column, f: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.transform(column.expr, expressionFunction(f)) + ArrayTransform(column.expr, createLambda(f)) } /** @@ -3470,7 +3490,7 @@ object functions { * @group collection_funcs */ def transform(column: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.transform(column.expr, expressionFunction(f)) + ArrayTransform(column.expr, createLambda(f)) } /** @@ -3479,7 +3499,7 @@ object functions { * @group collection_funcs */ def exists(column: Column, f: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.exists(column.expr, expressionFunction(f)) + ArrayExists(column.expr, createLambda(f)) } /** @@ -3488,7 +3508,7 @@ object functions { * @group collection_funcs */ def filter(column: Column, f: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.filter(column.expr, expressionFunction(f)) + ArrayFilter(column.expr, createLambda(f)) } /** @@ -3500,11 +3520,11 @@ object functions { */ def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column], finish: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.aggregate( + ArrayAggregate( expr.expr, zero.expr, - expressionFunction(merge), - expressionFunction(finish) + createLambda(merge), + createLambda(finish) ) } @@ -3526,7 +3546,7 @@ object functions { */ def zip_with(left: Column, right: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + ZipWith(left.expr, right.expr, createLambda(f)) } /** @@ -3536,7 +3556,7 @@ object functions { * @group collection_funcs */ def transform_keys(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + TransformKeys(expr.expr, createLambda(f)) } /** @@ -3546,7 +3566,7 @@ object functions { * @group collection_funcs */ def transform_values(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { - 
HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + TransformValues(expr.expr, createLambda(f)) } /** @@ -3555,7 +3575,7 @@ object functions { * @group collection_funcs */ def map_filter(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + MapFilter(expr.expr, createLambda(f)) } /** @@ -3565,7 +3585,7 @@ object functions { */ def map_zip_with(left: Column, right: Column, f: Function3[Column, Column, Column, Column]): Column = withExpr { - HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + MapZipWith(left.expr, right.expr, createLambda(f)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6a9ab0413c455..8efa69e4d4322 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1930,13 +1930,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, 9, 11, 10, 6)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), x => x + 1)), + checkAnswer(df.select(transform(col("i"), x => x + 1)), Seq( Row(Seq(2, 10, 9, 8)), Row(Seq(6, 9, 10, 8, 3)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), Seq( Row(Seq(1, 10, 10, 10)), Row(Seq(5, 9, 11, 10, 6)), @@ -1972,13 +1972,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, null, 10, 12, 11, 7)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), x => x + 1)), + checkAnswer(df.select(transform(col("i"), x => x + 1)), Seq( Row(Seq(2, 10, 9, null, 8)), Row(Seq(6, null, 9, 10, 8, 3)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), Seq( Row(Seq(1, 10, 10, null, 11)), Row(Seq(5, null, 10, 12, 11, 7)), @@ -2014,13 +2014,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b0", null, "c2", null)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("s"), x => concat(x, x))), + checkAnswer(df.select(transform(col("s"), x => concat(x, x))), Seq( Row(Seq("cc", "aa", "bb")), Row(Seq("bb", null, "cc", null)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("s"), (x, i) => concat(x, i))), + checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))), Seq( Row(Seq("c0", "a1", "b2")), Row(Seq("b0", null, "c2", null)), @@ -2070,13 +2070,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq("b", null, "c", null, null))), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("arg"), arg => arg)), + checkAnswer(df.select(transform(col("arg"), arg => arg)), Seq( Row(Seq("c", "a", "b")), Row(Seq("b", null, "c", null)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("arg"), _ => df("arg"))), + checkAnswer(df.select(transform(col("arg"), _ => col("arg"))), Seq( Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), Row(Seq( @@ -2086,7 +2086,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq("b", null, "c", null))), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("arg"), x => concat(df("arg"), array(x)))), + 
checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))), Seq( Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), Row(Seq( @@ -2143,8 +2143,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(1 -> 10), Map(3 -> -3)))) checkAnswer(dfInts.select( - map_filter(dfInts("m"), (k, v) => k * 10 === v), - map_filter(dfInts("m"), (k, v) => k === (v * -1))), + map_filter(col("m"), (k, v) => k * 10 === v), + map_filter(col("m"), (k, v) => k === (v * -1))), Seq( Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), @@ -2161,8 +2161,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(), Map(2 -> Seq(-2, -2))))) checkAnswer(dfComplex.select( - map_filter(dfComplex("m"), (k, v) => k === element_at(v, 1)), - map_filter(dfComplex("m"), (k, v) => k === size(v))), + map_filter(col("m"), (k, v) => k === element_at(v, 1)), + map_filter(col("m"), (k, v) => k === size(v))), Seq( Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), Row(Map(), Map(2 -> Seq(-2, -2))))) @@ -2189,10 +2189,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) - val ex3a = intercept[MatchError] { - df.select(map_filter(df("i"), (k, v) => k > v)) + val ex3a = intercept[AnalysisException] { + df.select(map_filter(col("i"), (k, v) => k > v)) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires map type")) val ex4 = intercept[AnalysisException] { df.selectExpr("map_filter(a, (k, v) -> k > v)") @@ -2215,7 +2215,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), Seq( Row(Seq(8)), Row(Seq(8, 2)), @@ -2245,7 +2245,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), Seq( Row(Seq(8)), Row(Seq(8, 2)), @@ -2275,7 +2275,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b", "c")), Row(Seq.empty), Row(null))) - checkAnswer(df.select(filter(df("s"), x => x.isNotNull)), + checkAnswer(df.select(filter(col("s"), x => x.isNotNull)), Seq( Row(Seq("c", "a", "b")), Row(Seq("b", "c")), @@ -2308,10 +2308,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2a = intercept[MatchError] { - df.select(filter(df("i"), x => x)) + val ex2a = intercept[AnalysisException] { + df.select(filter(col("i"), x => x)) } - assert(ex2a.getMessage.contains("IntegerType")) + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex3 = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") @@ -2319,7 +2319,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) val ex3a = intercept[AnalysisException] { - df.select(filter(df("s"), x => x)) + df.select(filter(col("s"), x => x)) } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) 
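The switch from MatchError and NoSuchElementException to AnalysisException in the assertions above is the user-visible payoff of deferring lambda binding: the Column-based helpers no longer pattern match on the input's DataType while the Column is being built, so an ill-typed call now surfaces as a proper analysis error. A minimal sketch of both outcomes, assuming an active SparkSession named spark (the data and column names here are illustrative, not taken from the suite):

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.functions._

    // "xs" is an array column, "i" is a plain integer column.
    val df = spark.range(1).selectExpr("array(1, 2, 3) AS xs", "CAST(id AS INT) AS i")

    // Well-typed: the unresolved lambda is bound to the array element type during analysis.
    df.select(filter(col("xs"), x => x % 2 === 0)).show()

    // Ill-typed: this call used to fail with a MatchError while constructing the Column;
    // with deferred binding it reaches the analyzer, which reports the type mismatch.
    try {
      df.select(filter(col("i"), x => x % 2 === 0))
    } catch {
      case e: AnalysisException =>
        println(e.getMessage) // e.g. "data type mismatch: argument 1 requires array type"
    }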
@@ -2344,7 +2344,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(false), Row(false), Row(null))) - checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), Seq( Row(true), Row(false), @@ -2374,7 +2374,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(false), Row(false), Row(null))) - checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), Seq( Row(true), Row(false), @@ -2404,7 +2404,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(false), Row(null))) - checkAnswer(df.select(exists(df("s"), x => x.isNull)), + checkAnswer(df.select(exists(col("s"), x => x.isNull)), Seq( Row(false), Row(true), @@ -2437,10 +2437,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2a = intercept[MatchError] { - df.select(exists(df("i"), x => x)) + val ex2a = intercept[AnalysisException] { + df.select(exists(col("i"), x => x)) } - assert(ex2a.getMessage.contains("IntegerType")) + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } @@ -2479,13 +2479,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(310), Row(0), Row(null))) - checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), Seq( Row(25), Row(31), Row(0), Row(null))) - checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x, _ * 10)), + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), Seq( Row(250), Row(310), @@ -2522,7 +2522,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(0), Row(0), Row(null))) - checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), Seq( Row(25), Row(null), @@ -2530,7 +2530,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null))) checkAnswer( df.select( - aggregate(df("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), + aggregate(col("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), Seq( Row(250), Row(0), @@ -2567,7 +2567,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(""), Row("c"), Row(null))) - checkAnswer(df.select(aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x))), + checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), Seq( Row("acab"), Row(null), @@ -2575,7 +2575,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null))) checkAnswer( df.select( - aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x), acc => coalesce(acc, lit("")))), + aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x), + acc => coalesce(acc, lit("")))), Seq( Row("acab"), Row(""), @@ -2613,10 +2614,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3a = intercept[MatchError] { - df.select(aggregate(df("i"), lit(0), (acc, x) => x)) + val ex3a = intercept[AnalysisException] { + df.select(aggregate(col("i"), lit(0), (acc, x) => x)) } -
assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") @@ -2624,7 +2625,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) val ex4a = intercept[AnalysisException] { - df.select(aggregate(df("s"), lit(0), (acc, x) => x)) + df.select(aggregate(col("s"), lit(0), (acc, x) => x)) } assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) @@ -2672,7 +2673,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map("a" -> Row("d", null))), Row(null))) - checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => struct(v1, v2))), + checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), (k, v1, v2) => struct(v1, v2))), Seq( Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), @@ -2696,30 +2697,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) - val ex2a = intercept[NoSuchElementException] { - df.select(map_zip_with(df("mis"), df("mmi"), (x, y, z) => concat(x, y, z))) + val ex2a = intercept[AnalysisException] { + df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) } - assert(ex2a.getMessage.contains("None.get")) + assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") } assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) - val ex3a = intercept[MatchError] { - df.select(map_zip_with(df("i"), df("mis"), (x, y, z) => concat(x, y, z))) + val ex3a = intercept[AnalysisException] { + df.select(map_zip_with(col("i"), col("mis"), (x, y, z) => concat(x, y, z))) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("type mismatch: argument 1 requires map type")) val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) - val ex4a = intercept[MatchError] { - df.select(map_zip_with(df("mis"), df("i"), (x, y, z) => concat(x, y, z))) + val ex4a = intercept[AnalysisException] { + df.select(map_zip_with(col("mis"), col("i"), (x, y, z) => concat(x, y, z))) } - assert(ex4a.getMessage.contains("IntegerType")) + assert(ex4a.getMessage.contains("type mismatch: argument 2 requires map type")) val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") @@ -2749,7 +2751,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) - checkAnswer(dfExample1.select(transform_keys(dfExample1("i"), (k, v) => k + v)), + checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) checkAnswer(dfExample2.selectExpr("transform_keys(j, " + @@ -2758,7 +2760,7 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext { checkAnswer(dfExample2.select( transform_keys( - dfExample2("j"), + col("j"), (k, v) => element_at( map_from_arrays( array(lit(1), lit(2), lit(3)), @@ -2773,33 +2775,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) - checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => (v * 2).cast("bigint") + k)), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) - checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), (k, v) => k + v)), + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), Seq(Row(Map(true -> true, true -> false)))) - checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), (k, v) => k % 2 === 0 || v)), + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), Seq(Row(Map(true -> true, true -> false)))) checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => when(v, k * 2).otherwise(k * 3))), Seq(Row(Map(50 -> true, 78 -> false)))) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) - checkAnswer(dfExample4.select(transform_keys(dfExample4("y"), + checkAnswer(dfExample4.select(transform_keys(col("y"), (k, v) => array_contains(k, lit(3)) && v)), Seq(Row(Map(false -> false)))) } @@ -2840,7 +2842,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("Cannot use null as map key")) val ex3a = intercept[Exception] { - dfExample1.select(transform_keys(dfExample1("i"), (k, v) => v)).show() + dfExample1.select(transform_keys(col("i"), (k, v) => v)).show() } assert(ex3a.getMessage.contains("Cannot use null as map key")) @@ -2909,26 +2911,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), Seq(Row(Map(1 -> 3)))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k + v)), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k + v)), Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) checkAnswer(dfExample2.select( - transform_values(dfExample2("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), + transform_values(col("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), Seq(Row(Map(false -> "false", true -> "def")))) - checkAnswer(dfExample2.select(transform_values(dfExample2("x"), + checkAnswer(dfExample2.select(transform_values(col("x"), (k, v) => (!k) && v === "abc")), Seq(Row(Map(false -> true, true -> false)))) - checkAnswer(dfExample3.select(transform_values(dfExample3("y"), (k, v) => v * v)), + checkAnswer(dfExample3.select(transform_values(col("y"), (k, v) => v * v)), Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) checkAnswer(dfExample3.select( - transform_values(dfExample3("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), + 
transform_values(col("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) checkAnswer( - dfExample3.select(transform_values(dfExample3("y"), (k, v) => concat(k, v.cast("string")))), + dfExample3.select(transform_values(col("y"), (k, v) => concat(k, v.cast("string")))), Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) val testMap = map_from_arrays( @@ -2937,16 +2939,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - dfExample4.select(transform_values(dfExample4("z"), + dfExample4.select(transform_values(col("z"), (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))), Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) checkAnswer( - dfExample4.select(transform_values(dfExample4("z"), (k, v) => k - v)), + dfExample4.select(transform_values(col("z"), (k, v) => k - v)), Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) checkAnswer( - dfExample5.select(transform_values(dfExample5("c"), (k, v) => k + size(v))), + dfExample5.select(transform_values(col("c"), (k, v) => k + size(v))), Seq(Row(Map(1 -> 3)))) } @@ -2992,26 +2994,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), Seq(Row(Map.empty[BigInt, BigInt]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(null).cast("int"))), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k)), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k)), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v)), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v)), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(0))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(0))), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit("value"))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit("value"))), Seq(Row(Map.empty[Integer, String]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(true))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(true))), Seq(Row(Map.empty[Integer, Boolean]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v.cast("bigint"))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v.cast("bigint"))), Seq(Row(Map.empty[BigInt, BigInt]))) } @@ -3038,12 +3040,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) - checkAnswer(dfExample1.select(transform_values(dfExample1("a"), + checkAnswer(dfExample1.select(transform_values(col("a"), (k, v) => lit(null).cast("int"))), Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) checkAnswer(dfExample2.select( - transform_values(dfExample2("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) + transform_values(col("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) ), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) } @@ -3085,10 +3087,11 @@ class DataFrameFunctionsSuite 
extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains( "data type mismatch: argument 1 requires map type")) - val ex3a = intercept[MatchError] { - dfExample3.select(transform_values(dfExample3("x"), (k, v) => k + 1)) + val ex3a = intercept[AnalysisException] { + dfExample3.select(transform_values(col("x"), (k, v) => k + 1)) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() @@ -3143,7 +3146,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { expectedValue1 ) checkAnswer( - df.select(zip_with(df("val1"), df("val2"), (x, y) => struct(y, x))), + df.select(zip_with(col("val1"), col("val2"), (x, y) => struct(y, x))), expectedValue1 ) } @@ -3167,10 +3170,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("zip_with(i, a2, (acc, x) -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3a = intercept[MatchError] { + val ex3a = intercept[AnalysisException] { df.select(zip_with(df("i"), df("a2"), (acc, x) => x)) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex4 = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a, (acc, x) -> x)") } From b03399a9d2db75b16a153ac055228798bb99591b Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Thu, 25 Jul 2019 22:10:03 +0900 Subject: [PATCH 07/38] Resolve conflict between Java Function and Scala Function --- .../org/apache/spark/sql/functions.scala | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ae0d8e75b1569..7f5217f78b84d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,7 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.Stable -import org.apache.spark.api.java.function._ +import org.apache.spark.api.java.function.{Function => JavaFunction, Function2 => JavaFunction2, Function3 => JavaFunction3} import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -3338,20 +3338,20 @@ object functions { LambdaFunction(function, Seq(x, y, z)) } - private def createLambda(f: Function[Column, Column]) = { + private def createLambda(f: JavaFunction[Column, Column]) = { val x = UnresolvedNamedLambdaVariable(Seq("x")) val function = f.call(Column(x)).expr LambdaFunction(function, Seq(x)) } - private def createLambda(f: Function2[Column, Column, Column]) = { + private def createLambda(f: JavaFunction2[Column, Column, Column]) = { val x = UnresolvedNamedLambdaVariable(Seq("x")) val y = UnresolvedNamedLambdaVariable(Seq("y")) val function = f.call(Column(x), Column(y)).expr LambdaFunction(function, Seq(x, y)) } - private def createLambda(f: Function3[Column, Column, Column, Column]) = { + private def createLambda(f: JavaFunction3[Column, Column, Column, Column]) = { val x = UnresolvedNamedLambdaVariable(Seq("x")) val y = UnresolvedNamedLambdaVariable(Seq("y")) val z = UnresolvedNamedLambdaVariable(Seq("z")) @@ -3479,7 +3479,7 @@ object functions { * * @group collection_funcs */ - def transform(column: 
Column, f: Function[Column, Column]): Column = withExpr { + def transform(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayTransform(column.expr, createLambda(f)) } @@ -3489,7 +3489,7 @@ object functions { * * @group collection_funcs */ - def transform(column: Column, f: Function2[Column, Column, Column]): Column = withExpr { + def transform(column: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { ArrayTransform(column.expr, createLambda(f)) } @@ -3498,7 +3498,7 @@ object functions { * * @group collection_funcs */ - def exists(column: Column, f: Function[Column, Column]): Column = withExpr { + def exists(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayExists(column.expr, createLambda(f)) } @@ -3507,7 +3507,7 @@ object functions { * * @group collection_funcs */ - def filter(column: Column, f: Function[Column, Column]): Column = withExpr { + def filter(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayFilter(column.expr, createLambda(f)) } @@ -3518,8 +3518,8 @@ object functions { * * @group collection_funcs */ - def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column], - finish: Function[Column, Column]): Column = withExpr { + def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column], + finish: JavaFunction[Column, Column]): Column = withExpr { ArrayAggregate( expr.expr, zero.expr, @@ -3534,8 +3534,9 @@ object functions { * * @group collection_funcs */ - def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column]): Column = - aggregate(expr, zero, merge, new Function[Column, Column] { def call(c: Column): Column = c }) + def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column]): Column = + aggregate( + expr, zero, merge, new JavaFunction[Column, Column] { def call(c: Column): Column = c }) /** * (Java-specific) Merge two given arrays, element-wise, into a single array using a function.
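The JavaFunction aliases matter because the simple name Function is already taken on the Scala side: Predef defines type Function[-A, +B] = Function1[A, B], so an unqualified Function[Column, Column] in these Java-specific overloads can clash with the Scala-specific Column => Column signatures. A hypothetical sketch of the pattern the rename enables (OverloadSketch and twice are illustrative names, not part of the patch):

    import org.apache.spark.api.java.function.{Function => JavaFunction}
    import org.apache.spark.sql.Column

    object OverloadSketch {
      // Scala-specific overload: a plain Column => Column (scala.Function1).
      def twice(c: Column, f: Column => Column): Column = f(f(c))

      // Java-specific overload: the Java interface, unambiguous under its alias,
      // is invoked through its single abstract method, call().
      def twice(c: Column, f: JavaFunction[Column, Column]): Column = f.call(f.call(c))
    }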
@@ -3544,7 +3545,7 @@ object functions { * * @group collection_funcs */ - def zip_with(left: Column, right: Column, f: Function2[Column, Column, Column]): Column = + def zip_with(left: Column, right: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { ZipWith(left.expr, right.expr, createLambda(f)) } @@ -3555,7 +3556,7 @@ object functions { * * @group collection_funcs */ - def transform_keys(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { + def transform_keys(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { TransformKeys(expr.expr, createLambda(f)) } @@ -3565,7 +3566,7 @@ object functions { * * @group collection_funcs */ - def transform_values(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { + def transform_values(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { TransformValues(expr.expr, createLambda(f)) } @@ -3574,7 +3575,7 @@ object functions { * * @group collection_funcs */ - def map_filter(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { + def map_filter(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { MapFilter(expr.expr, createLambda(f)) } @@ -3584,7 +3585,7 @@ object functions { * @group collection_funcs */ def map_zip_with(left: Column, right: Column, - f: Function3[Column, Column, Column, Column]): Column = withExpr { + f: JavaFunction3[Column, Column, Column, Column]): Column = withExpr { MapZipWith(left.expr, right.expr, createLambda(f)) } From 79d6f841c1aa43a024055ba17c255cb4a42e9f02 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Wed, 27 Mar 2019 23:34:48 -0400 Subject: [PATCH 08/38] Adds higher order functions to scala API --- .../expressions/higherOrderFunctions.scala | 133 +++++++ .../HigherOrderFunctionsSuite.scala | 125 +------ .../org/apache/spark/sql/functions.scala | 135 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 350 +++++++++++++++++- 4 files changed, 617 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b326e1c4c6af4..cc9e42c893d09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -29,6 +29,139 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods +/** + * Helper methods for constructing higher order functions. 
+ */ +object HigherOrderUtils { + def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + + def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size == argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType == dataType) + assert(arg.nullable == nullable) + } + f + } + + // Array-based helpers + def filter(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def exists(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, et, cn, merge), + createLambda(zeroType, true, finish)) + .bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) + } + + // Map-based helpers + + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) 
+ } + + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) + } +} + /** * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 1411be8007deb..046c67a8a665c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -24,102 +24,7 @@ import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import org.apache.spark.sql.catalyst.dsl.expressions._ - - private def createLambda( - dt: DataType, - nullable: Boolean, - f: Expression => Expression): Expression = { - val lv = NamedLambdaVariable("arg", dt, nullable) - val function = f(lv) - LambdaFunction(function, Seq(lv)) - } - - private def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - f: (Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val function = f(lv1, lv2) - LambdaFunction(function, Seq(lv1, lv2)) - } - - private def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - dt3: DataType, - nullable3: Boolean, - f: (Expression, Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) - val function = f(lv1, lv2, lv3) - LambdaFunction(function, Seq(lv1, lv2, lv3)) - } - - private def validateBinding( - e: Expression, - argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { - case f: LambdaFunction => - assert(f.arguments.size === argInfo.size) - f.arguments.zip(argInfo).foreach { - case (arg, (dataType, nullable)) => - assert(arg.dataType === dataType) - assert(arg.nullable === nullable) - } - f - } - - def transform(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) - } - - def filter(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformKeys(expr, createLambda(kt, false, vt, vcn, 
f)).bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression, - finish: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - val zeroType = zero.dataType - ArrayAggregate( - expr, - zero, - createLambda(zeroType, true, et, cn, merge), - createLambda(zeroType, true, finish)) - .bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression): Expression = { - aggregate(expr, zero, merge, identity) - } - - def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } + import org.apache.spark.sql.catalyst.expressions.HigherOrderUtils._ test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) @@ -163,10 +68,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapFilter") { - def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), @@ -244,11 +145,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ArrayExists") { - def exists(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) - } - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) @@ -481,16 +377,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapZipWith") { - def map_zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression, Expression) => Expression): Expression = { - val MapType(kt, vt1, _) = left.dataType - val MapType(_, vt2, _) = right.dataType - MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) - .bind(validateBinding) - } - val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4), @@ -573,15 +459,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ZipWith") { - def zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression) => Expression): Expression = { - val ArrayType(leftT, _) = left.dataType - val ArrayType(rightT, _) = right.dataType - ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) - } - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
index 1abda54518fd3..da5a011be881a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3321,6 +3321,141 @@ object functions { ArrayExcept(col1.expr, col2.expr) } + private def expressionFunction(f: Column => Column) + : Expression => Expression = + x => f(Column(x)).expr + + private def expressionFunction(f: (Column, Column) => Column) + : (Expression, Expression) => Expression = + (x, y) => f(Column(x), Column(y)).expr + + private def expressionFunction(f: (Column, Column, Column) => Column) + : (Expression, Expression, Expression) => Expression = + (x, y, z) => f(Column(x), Column(y), Column(z)).expr + + /** + * Returns an array of elements after applying a transformation to each element + * in the input array. + * + * @group collection_funcs + */ + def transform(column: Column, f: Column => Column): Column = withExpr { + HigherOrderUtils.transform(column.expr, expressionFunction(f)) + } + + /** + * Returns an array of elements after applying a transformation to each element + * in the input array. + * + * @group collection_funcs + */ + def transform(column: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.transform(column.expr, expressionFunction(f)) + } + + /** + * Returns whether a predicate holds for one or more elements in the array. + * + * @group collection_funcs + */ + def exists(column: Column, f: Column => Column): Column = withExpr { + HigherOrderUtils.exists(column.expr, expressionFunction(f)) + } + + /** + * Returns an array of elements for which a predicate holds in a given array. + * + * @group collection_funcs + */ + def filter(column: Column, f: Column => Column): Column = withExpr { + HigherOrderUtils.filter(column.expr, expressionFunction(f)) + } + + /** + * Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. The final state is converted into the final result + * by applying a finish function. + * + * @group collection_funcs + */ + def aggregate( + expr: Column, + zero: Column, + merge: (Column, Column) => Column, + finish: Column => Column): Column = withExpr { + HigherOrderUtils.aggregate( + expr.expr, + zero.expr, + expressionFunction(merge), + expressionFunction(finish) + ) + } + + /** + * Applies a binary operator to an initial state and all elements in the array, + * and reduces this to a single state. + * + * @group collection_funcs + */ + def aggregate( + expr: Column, + zero: Column, + merge: (Column, Column) => Column): Column = + aggregate(expr, zero, merge, identity) + + /** + * Merge two given arrays, element-wise, into a single array using a function. + * If one array is shorter, nulls are appended at the end to match the length of the longer + * array, before applying the function. + * + * @group collection_funcs + */ + def zip_with( + left: Column, + right: Column, + f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + } + + /** + * Applies a function to every key-value pair in a map and returns + * a map with the results of those applications as the new keys for the pairs.
+ * + * @group collection_funcs + */ + def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + } + + /** + * Applies a function to every key-value pair in a map and returns + * a map with the results of those applications as the new values for the pairs. + * + * @group collection_funcs + */ + def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + } + + /** + * Returns a map whose key-value pairs satisfy a predicate. + * + * @group collection_funcs + */ + def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr { + HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + } + + /** + * Merge two given maps, key-wise into a single map using a function. + * + * @group collection_funcs + */ + def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = + withExpr { + HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + } + /** * Creates a new row for each element in the given array or map column. * Uses the default column name `col` for elements in the array and diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3f16f64f4b900..69a1e2130d990 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1930,6 +1930,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, 9, 11, 10, 6)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("i"), x => x + 1)), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -1960,6 +1972,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, null, 10, 12, 11, 7)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("i"), x => x + 1)), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -1990,6 +2014,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b0", null, "c2", null)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("s"), x => concat(x, x))), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("s"), (x, i) => concat(x, i))), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2034,6 +2070,32 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq("b", null, "c", null, null))), Row(Seq.empty), Row(null))) + checkAnswer(df.select(transform(df("arg"), arg => arg)), + Seq( + Row(Seq("c", "a", "b")), 
+ Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("arg"), _ => df("arg"))), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(df("arg"), x => concat(df("arg"), array(x)))), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2080,6 +2142,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), Row(Map(1 -> 10), Map(3 -> -3)))) + checkAnswer(dfInts.select( + map_filter(dfInts("m"), (k, v) => k * 10 === v), + map_filter(dfInts("m"), (k, v) => k === (v * -1))), + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) + val dfComplex = Seq( Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") @@ -2090,6 +2160,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), Row(Map(), Map(2 -> Seq(-2, -2))))) + checkAnswer(dfComplex.select( + map_filter(dfComplex("m"), (k, v) => k === element_at(v, 1)), + map_filter(dfComplex("m"), (k, v) => k === size(v))), + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) + // Invalid use cases val df = Seq( (Map(1 -> "a"), 1), @@ -2112,6 +2189,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + val ex3a = intercept[MatchError] { + df.select(map_filter(df("i"), (k, v) => k > v)) + } + assert(ex3a.getMessage.contains("IntegerType")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_filter(a, (k, v) -> k > v)") } @@ -2133,6 +2215,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2157,6 +2245,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) + checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2181,6 +2275,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b", "c")), Row(Seq.empty), Row(null))) + checkAnswer(df.select(filter(df("s"), x => x.isNotNull)), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2208,11 +2308,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } 
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[MatchError] { + df.select(filter(df("i"), x => x)) + } + assert(ex2a.getMessage.contains("IntegerType")) + val ex3 = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(filter(df("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("filter(a, x -> x)") } @@ -2234,6 +2344,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(false), Row(false), Row(null))) + checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2260,6 +2376,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null), Row(false), Row(null))) + checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2284,6 +2406,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(false), Row(null))) + checkAnswer(df.select(exists(df("s"), x => x.isNull)), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2311,11 +2439,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[MatchError] { + df.select(exists(df("i"), x => x)) + } + assert(ex2a.getMessage.contains("IntegerType")) + val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(exists(df("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("exists(a, x -> x)") } @@ -2343,6 +2481,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(310), Row(0), Row(null))) + checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x, _ * 10)), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2374,6 +2524,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(0), Row(0), Row(null))) + checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.select( + aggregate(df("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2405,6 +2569,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(""), Row("c"), Row(null))) + checkAnswer(df.select(aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x))), + Seq( + Row("acab"), + 
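In the four-argument aggregate exercised by the tests below, the finish function post-processes the final accumulated state; a minimal sketch with hypothetical data:

    import org.apache.spark.sql.functions._
    import spark.implicits._

    val df = Seq(Seq(1, 2, 3)).toDF("i")
    // sum the elements, then scale the final state by ten
    df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x, _ * 10))  // 60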
Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.select( + aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x), acc => coalesce(acc, lit("")))), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2437,11 +2615,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3a = intercept[MatchError] { + df.select(aggregate(df("i"), lit(0), (acc, x) => x)) + } + assert(ex3a.getMessage.contains("IntegerType")) + val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + val ex4a = intercept[AnalysisException] { + df.select(aggregate(df("s"), lit(0), (acc, x) => x)) + } + assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) + val ex5 = intercept[AnalysisException] { df.selectExpr("aggregate(a, 0, (acc, x) -> x)") } @@ -2462,6 +2650,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(10 -> null, 8 -> false, 4 -> null)), Row(Map(5 -> null)), Row(null))) + + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) } test("map_zip_with function - map of non-primitive types") { @@ -2478,6 +2673,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), Row(Map("a" -> Row("d", null))), Row(null))) + + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => struct(v1, v2))), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) } test("map_zip_with function - invalid") { @@ -2496,16 +2698,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) + val ex2a = intercept[NoSuchElementException] { + df.select(map_zip_with(df("mis"), df("mmi"), (x, y, z) => concat(x, y, z))) + } + assert(ex2a.getMessage.contains("None.get")) + val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") } assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + val ex3a = intercept[MatchError] { + df.select(map_zip_with(df("i"), df("mis"), (x, y, z) => concat(x, y, z))) + } + assert(ex3a.getMessage.contains("IntegerType")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + val ex4a = intercept[MatchError] { + df.select(map_zip_with(df("mis"), df("i"), (x, y, z) => concat(x, y, z))) + } + assert(ex4a.getMessage.contains("IntegerType")) + val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } @@ -2534,27 +2751,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) 
+ checkAnswer(dfExample1.select(transform_keys(dfExample1("i"), (k, v) => k + v)), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.select( + transform_keys( + dfExample2("j"), + (k, v) => element_at( + map_from_arrays( + array(lit(1), lit(2), lit(3)), + array(lit("one"), lit("two"), lit("three")) + ), + k + ) + ) + ), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), + (k, v) => (v * 2).cast("bigint") + k)), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), (k, v) => k + v)), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), Seq(Row(Map(true -> true, true -> false)))) + checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), (k, v) => k % 2 === 0 || v)), + Seq(Row(Map(true -> true, true -> false)))) + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), + (k, v) => when(v, k * 2).otherwise(k * 3))), Seq(Row(Map(50 -> true, 78 -> false)))) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) + + checkAnswer(dfExample4.select(transform_keys(dfExample4("y"), + (k, v) => array_contains(k, lit(3)) && v)), + Seq(Row(Map(false -> false)))) } // Test with local relation, the Project will be evaluated without codegen @@ -2592,6 +2841,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("Cannot use null as map key")) + val ex3a = intercept[Exception] { + dfExample1.select(transform_keys(dfExample1("i"), (k, v) => v)).show() + } + assert(ex3a.getMessage.contains("Cannot use null as map key")) + val ex4 = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") } @@ -2656,6 +2910,46 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), Seq(Row(Map(1 -> 3)))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k + v)), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.select( + transform_values(dfExample2("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.select(transform_values(dfExample2("x"), + (k, v) => (!k) && v === "abc")), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.select(transform_values(dfExample3("y"), (k, v) => v * v)), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.select( + transform_values(dfExample3("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), + Seq(Row(Map("a" -> "a:1", "b" -> 
"b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.select(transform_values(dfExample3("y"), (k, v) => concat(k, v.cast("string")))), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + val testMap = map_from_arrays( + array(lit(1), lit(2), lit(3)), + array(lit("one"), lit("two"), lit("three")) + ) + + checkAnswer( + dfExample4.select(transform_values(dfExample4("z"), + (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.select(transform_values(dfExample4("z"), (k, v) => k - v)), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.select(transform_values(dfExample5("c"), (k, v) => k + size(v))), + Seq(Row(Map(1 -> 3)))) } // Test with local relation, the Project will be evaluated without codegen @@ -2699,6 +2993,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), Seq(Row(Map.empty[BigInt, BigInt]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), + (k, v) => lit(null).cast("int"))), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k)), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v)), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(0))), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit("value"))), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(true))), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v.cast("bigint"))), + Seq(Row(Map.empty[BigInt, BigInt]))) } testEmpty() @@ -2723,6 +3039,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr( "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + + checkAnswer(dfExample1.select(transform_values(dfExample1("a"), + (k, v) => lit(null).cast("int"))), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.select( + transform_values(dfExample2("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) + ), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) } testNullValue() @@ -2761,6 +3086,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains( "data type mismatch: argument 1 requires map type")) + + val ex3a = intercept[MatchError] { + dfExample3.select(transform_values(dfExample3("x"), (k, v) => k + 1)) + } + assert(ex3a.getMessage.contains("IntegerType")) } testInvalidLambdaFunctions() @@ -2787,10 +3117,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq.empty), Row(null)) checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) + checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)), expectedValue1) val expectedValue2 = Seq( Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) + checkAnswer( + 
df2.select(zip_with(df2("val1"), df2("val2"), (x, y) => struct(y, x))),
+      expectedValue2
+    )
   }
 
   test("arrays zip_with function - for non-primitive types") {
@@ -2805,7 +3140,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(Seq(Row("x", "a"), Row("y", null))),
       Row(Seq.empty),
       Row(null))
-    checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1)
+    checkAnswer(
+      df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"),
+      expectedValue1
+    )
+    checkAnswer(
+      df.select(zip_with(df("val1"), df("val2"), (x, y) => struct(y, x))),
+      expectedValue1
+    )
   }
 
   test("arrays zip_with function - invalid") {
@@ -2827,6 +3169,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("zip_with(i, a2, (acc, x) -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type"))
+    val ex3a = intercept[MatchError] {
+      df.select(zip_with(df("i"), df("a2"), (acc, x) => x))
+    }
+    assert(ex3a.getMessage.contains("IntegerType"))
     val ex4 = intercept[AnalysisException] {
       df.selectExpr("zip_with(a1, a, (acc, x) -> x)")
     }

From 7adaf9c9627b7057b2dbadba0efcd3928eea43f3 Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Thu, 28 Mar 2019 08:59:01 -0400
Subject: [PATCH 09/38] Add (Scala-specific) note to higher order functions

These signatures won't work in Java as they rely on Scala lambdas

---
 .../org/apache/spark/sql/functions.scala | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index da5a011be881a..72b9290ffe2a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3334,7 +3334,7 @@ object functions {
     (x, y, z) => f(Column(x), Column(y), Column(z)).expr
 
   /**
-   * Returns an array of elements after applying a transformation to each element
+   * (Scala-specific) Returns an array of elements after applying a transformation to each element
    * in the input array.
    *
    * @group collection_funcs
@@ -3344,7 +3344,7 @@ object functions {
   }
 
   /**
-   * Returns an array of elements after applying a transformation to each element
+   * (Scala-specific) Returns an array of elements after applying a transformation to each element
    * in the input array.
    *
    * @group collection_funcs
@@ -3354,7 +3354,7 @@ object functions {
   }
 
   /**
-   * Returns whether a predicate holds for one or more elements in the array.
+   * (Scala-specific) Returns whether a predicate holds for one or more elements in the array.
    *
    * @group collection_funcs
    */
@@ -3363,7 +3363,7 @@ object functions {
   }
 
   /**
-   * Returns an array of elements for which a predicate holds in a given array.
+   * (Scala-specific) Returns an array of elements for which a predicate holds in a given array.
    *
    * @group collection_funcs
    */
@@ -3372,7 +3372,7 @@ object functions {
   }
 
   /**
-   * Applies a binary operator to an initial state and all elements in the array,
+   * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
    * and reduces this to a single state. The final state is converted into the final result
    * by applying a finish function.
    *
    * @group collection_funcs
@@ -3392,7 +3392,7 @@ object functions {
 
   /**
-   * Applies a binary operator to an initial state and all elements in the array,
+   * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
    * and reduces this to a single state.
    *
    * @group collection_funcs
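Since zip_with pads the shorter array with nulls before applying the function, a call like the following behaves as sketched (hypothetical data, SparkSession assumed):

    import org.apache.spark.sql.functions._
    import spark.implicits._

    val df = Seq((Seq(1, 2, 3), Seq(10, 20))).toDF("a", "b")
    df.select(zip_with(df("a"), df("b"), (x, y) => x + y))  // [11, 22, null]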
@@ -3404,7 +3404,7 @@ object functions {
     aggregate(expr, zero, merge, identity)
 
   /**
-   * Merge two given arrays, element-wise, into a single array using a function.
+   * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function.
    * If one array is shorter, nulls are appended at the end to match the length of the longer
    * array, before applying the function.
    *
@@ -3418,7 +3418,7 @@ object functions {
   }
 
   /**
-   * Applies a function to every key-value pair in a map and returns
+   * (Scala-specific) Applies a function to every key-value pair in a map and returns
    * a map with the results of those applications as the new keys for the pairs.
    *
    * @group collection_funcs
@@ -3428,7 +3428,7 @@ object functions {
   }
 
   /**
-   * Applies a function to every key-value pair in a map and returns
+   * (Scala-specific) Applies a function to every key-value pair in a map and returns
    * a map with the results of those applications as the new values for the pairs.
    *
    * @group collection_funcs
@@ -3438,7 +3438,7 @@ object functions {
   }
 
   /**
-   * Returns a map whose key-value pairs satisfy a predicate.
+   * (Scala-specific) Returns a map whose key-value pairs satisfy a predicate.
    *
    * @group collection_funcs
    */
@@ -3447,7 +3447,7 @@ object functions {
   }
 
   /**
-   * Merge two given maps, key-wise into a single map using a function.
+   * (Scala-specific) Merge two given maps, key-wise into a single map using a function.
    *
    * @group collection_funcs
    */

From ac5c1c2dadc6b9fa22e0ca21e91d65628f70ac03 Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Thu, 28 Mar 2019 09:07:06 -0400
Subject: [PATCH 10/38] Follow style guide more closely

---
 .../org/apache/spark/sql/functions.scala | 44 +++++++------------
 1 file changed, 16 insertions(+), 28 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 72b9290ffe2a4..99ad4ff4d6a31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3321,16 +3321,14 @@ object functions {
     ArrayExcept(col1.expr, col2.expr)
   }
 
-  private def expressionFunction(f: Column => Column)
-    : Expression => Expression =
+  private def expressionFunction(f: Column => Column): Expression => Expression =
     x => f(Column(x)).expr
 
-  private def expressionFunction(f: (Column, Column) => Column)
-    : (Expression, Expression) => Expression =
+  private def expressionFunction(f: (Column, Column) => Column): (Expression, Expression) => Expression =
     (x, y) => f(Column(x), Column(y)).expr
 
   private def expressionFunction(f: (Column, Column, Column) => Column)
-    : (Expression, Expression, Expression) => Expression =
+      : (Expression, Expression, Expression) => Expression =
     (x, y, z) => f(Column(x), Column(y), Column(z)).expr
 
   /**
@@ -3378,18 +3376,15 @@ object functions {
    *
    * @group collection_funcs
    */
-  def aggregate(
-      expr: Column,
-      zero: Column,
-      merge: (Column, Column) => Column,
-      finish: Column => Column): Column = withExpr {
-    HigherOrderUtils.aggregate(
-      expr.expr,
-      zero.expr,
-      expressionFunction(merge),
-      expressionFunction(finish)
-    )
-  }
+  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, finish: Column => Column): Column =
+    withExpr {
+      HigherOrderUtils.aggregate(
+        expr.expr,
+        zero.expr,
+        expressionFunction(merge),
+        expressionFunction(finish)
+      )
+    }
 
   /**
    * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
@@ -3397,11 +3392,8 @@ object functions {
    *
    * @group collection_funcs
    */
-  def aggregate(
-      expr: Column,
-      zero: Column,
-      merge: (Column, Column) => Column): Column =
-    aggregate(expr, zero, merge, identity)
+  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column =
+    aggregate(expr, zero, merge, identity)
 
   /**
    * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function.
@@ -3410,10 +3402,7 @@ object functions {
    *
    * @group collection_funcs
    */
-  def zip_with(
-      left: Column,
-      right: Column,
-      f: (Column, Column) => Column): Column = withExpr {
+  def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr {
     HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f))
   }
@@ -3451,8 +3440,7 @@ object functions {
    *
    * @group collection_funcs
    */
-  def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column =
-    withExpr {
+  def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = withExpr {
     HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f))
   }

From 40ac418a338963af830a24e1cd559de75f44afd6 Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Thu, 28 Mar 2019 09:24:43 -0400
Subject: [PATCH 11/38] Fix scalastyle issues

---
 .../org/apache/spark/sql/functions.scala | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 99ad4ff4d6a31..78a2be5f46d3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3324,7 +3324,8 @@ object functions {
   private def expressionFunction(f: Column => Column): Expression => Expression =
     x => f(Column(x)).expr
 
-  private def expressionFunction(f: (Column, Column) => Column): (Expression, Expression) => Expression =
+  private def expressionFunction(f: (Column, Column) => Column)
+      : (Expression, Expression) => Expression =
     (x, y) => f(Column(x), Column(y)).expr
 
   private def expressionFunction(f: (Column, Column, Column) => Column)
@@ -3376,15 +3377,15 @@ object functions {
    *
    * @group collection_funcs
    */
-  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, finish: Column => Column): Column =
-    withExpr {
-      HigherOrderUtils.aggregate(
-        expr.expr,
-        zero.expr,
-        expressionFunction(merge),
-        expressionFunction(finish)
-      )
-    }
+  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column,
+      finish: Column => Column): Column = withExpr {
+    HigherOrderUtils.aggregate(
+      expr.expr,
+      zero.expr,
+      expressionFunction(merge),
+      expressionFunction(finish)
+    )
+  }
 
   /**
    * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
@@ -3440,9 +3441,10 @@ object functions {
    *
    * @group collection_funcs
    */
-  def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column =
-    withExpr {
-      HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f))
-    }
+  def map_zip_with(left: Column, right: Column,
+      f: (Column, Column, Column) => Column): Column = withExpr {
+    HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f))
+  }
 
 /**
  * Creates a new row for each element in the given array or map column.
From fb5f8ef8fa0c00a37a18ddd2abb1877079b3d2c3 Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Thu, 28 Mar 2019 19:01:20 -0400
Subject: [PATCH 12/38] Add Java-specific version of higher order function API

---
 .../org/apache/spark/sql/functions.scala | 129 +++++++++++++++++-
 1 file changed, 128 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 78a2be5f46d3d..cc3d23d943c65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -24,6 +24,7 @@
 import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.api.java.function._
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
@@ -3332,6 +3333,17 @@ object functions {
     : (Expression, Expression, Expression) => Expression =
     (x, y, z) => f(Column(x), Column(y), Column(z)).expr
 
+  private def expressionFunction(f: Function[Column, Column]): Expression => Expression =
+    x => f.call(Column(x)).expr
+
+  private def expressionFunction(f: Function2[Column, Column, Column])
+    : (Expression, Expression) => Expression =
+    (x, y) => f.call(Column(x), Column(y)).expr
+
+  private def expressionFunction(f: Function3[Column, Column, Column, Column])
+    : (Expression, Expression, Expression) => Expression =
+    (x, y, z) => f.call(Column(x), Column(y), Column(z)).expr
+
   /**
    * (Scala-specific) Returns an array of elements after applying a transformation to each element
    * in the input array.
@@ -3394,7 +3406,7 @@ object functions {
    * @group collection_funcs
    */
   def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column =
-    aggregate(expr, zero, merge, identity)
+    aggregate(expr, zero, merge, c => c)
 
   /**
    * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function.
@@ -3446,6 +3458,121 @@ object functions {
     HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f))
   }
 
+  /**
+   * (Java-specific) Returns an array of elements after applying a transformation to each element
+   * in the input array.
+   *
+   * @group collection_funcs
+   */
+  def transform(column: Column, f: Function[Column, Column]): Column = withExpr {
+    HigherOrderUtils.transform(column.expr, expressionFunction(f))
+  }
+
+  /**
+   * (Java-specific) Returns an array of elements after applying a transformation to each element
+   * in the input array.
+   *
+   * @group collection_funcs
+   */
+  def transform(column: Column, f: Function2[Column, Column, Column]): Column = withExpr {
+    HigherOrderUtils.transform(column.expr, expressionFunction(f))
+  }
+
+  /**
+   * (Java-specific) Returns whether a predicate holds for one or more elements in the array.
+   *
+   * @group collection_funcs
+   */
+  def exists(column: Column, f: Function[Column, Column]): Column = withExpr {
+    HigherOrderUtils.exists(column.expr, expressionFunction(f))
+  }
+
+  /**
+   * (Java-specific) Returns an array of elements for which a predicate holds in a given array.
+   *
+   * @group collection_funcs
+   */
+  def filter(column: Column, f: Function[Column, Column]): Column = withExpr {
+    HigherOrderUtils.filter(column.expr, expressionFunction(f))
+  }
+
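From Java these overloads would be called with lambdas; the same dispatch can be seen from Scala by passing an explicit instance of the Java function interface. A minimal sketch (the DataFrame df with an integer-array column i is hypothetical):

    import org.apache.spark.api.java.function.{Function => JFunction}
    import org.apache.spark.sql.{Column, functions}

    val addOne = new JFunction[Column, Column] {
      override def call(x: Column): Column = x.plus(1)
    }
    // resolves to the Java-specific transform(Column, Function[Column, Column]) overload
    val transformed = functions.transform(df.col("i"), addOne)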
+  /**
+   * (Java-specific) Applies a binary operator to an initial state and all elements in the array,
+   * and reduces this to a single state. The final state is converted into the final result
+   * by applying a finish function.
+   *
+   * @group collection_funcs
+   */
+  def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column],
+      finish: Function[Column, Column]): Column = withExpr {
+    HigherOrderUtils.aggregate(
+      expr.expr,
+      zero.expr,
+      expressionFunction(merge),
+      expressionFunction(finish)
+    )
+  }
+
+  /**
+   * (Java-specific) Applies a binary operator to an initial state and all elements in the array,
+   * and reduces this to a single state.
+   *
+   * @group collection_funcs
+   */
+  def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column]): Column =
+    aggregate(expr, zero, merge, new Function[Column, Column] { def call(c: Column): Column = c })
+
+  /**
+   * (Java-specific) Merge two given arrays, element-wise, into a single array using a function.
+   * If one array is shorter, nulls are appended at the end to match the length of the longer
+   * array, before applying the function.
+   *
+   * @group collection_funcs
+   */
+  def zip_with(left: Column, right: Column, f: Function2[Column, Column, Column]): Column =
+    withExpr {
+      HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f))
+    }
+
+  /**
+   * (Java-specific) Applies a function to every key-value pair in a map and returns
+   * a map with the results of those applications as the new keys for the pairs.
+   *
+   * @group collection_funcs
+   */
+  def transform_keys(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr {
+    HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f))
+  }
+
+  /**
+   * (Java-specific) Applies a function to every key-value pair in a map and returns
+   * a map with the results of those applications as the new values for the pairs.
+   *
+   * @group collection_funcs
+   */
+  def transform_values(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr {
+    HigherOrderUtils.transformValues(expr.expr, expressionFunction(f))
+  }
+
+  /**
+   * (Java-specific) Returns a map whose key-value pairs satisfy a predicate.
+   *
+   * @group collection_funcs
+   */
+  def map_filter(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr {
+    HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f))
+  }
+
+  /**
+   * (Java-specific) Merge two given maps, key-wise into a single map using a function.
+   *
+   * @group collection_funcs
+   */
+  def map_zip_with(left: Column, right: Column,
+      f: Function3[Column, Column, Column, Column]): Column = withExpr {
+    HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f))
+  }
+
 /**
  * Creates a new row for each element in the given array or map column.
* Uses the default column name `col` for elements in the array and From 85979d44e5d44112723d6a9b25ea278be529e15d Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Fri, 14 Jun 2019 15:34:08 -0400 Subject: [PATCH 13/38] Do not prematurely bind lambda variables --- .../expressions/higherOrderFunctions.scala | 133 -------------- .../HigherOrderFunctionsSuite.scala | 125 ++++++++++++- .../org/apache/spark/sql/functions.scala | 100 ++++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 165 +++++++++--------- 4 files changed, 268 insertions(+), 255 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index cc9e42c893d09..b326e1c4c6af4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -29,139 +29,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods -/** - * Helper methods for constructing higher order functions. - */ -object HigherOrderUtils { - def createLambda( - dt: DataType, - nullable: Boolean, - f: Expression => Expression): Expression = { - val lv = NamedLambdaVariable("arg", dt, nullable) - val function = f(lv) - LambdaFunction(function, Seq(lv)) - } - - def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - f: (Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val function = f(lv1, lv2) - LambdaFunction(function, Seq(lv1, lv2)) - } - - def createLambda( - dt1: DataType, - nullable1: Boolean, - dt2: DataType, - nullable2: Boolean, - dt3: DataType, - nullable3: Boolean, - f: (Expression, Expression, Expression) => Expression): Expression = { - val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) - val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) - val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) - val function = f(lv1, lv2, lv3) - LambdaFunction(function, Seq(lv1, lv2, lv3)) - } - - def validateBinding( - e: Expression, - argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { - case f: LambdaFunction => - assert(f.arguments.size == argInfo.size) - f.arguments.zip(argInfo).foreach { - case (arg, (dataType, nullable)) => - assert(arg.dataType == dataType) - assert(arg.nullable == nullable) - } - f - } - - // Array-based helpers - def filter(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def exists(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transform(expr: Expression, f: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) - } - - def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, 
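The motivation for this patch: the helpers above built lambdas from NamedLambdaVariable and called .bind(...) eagerly, fixing argument types before analysis. Constructing the lambda from UnresolvedNamedLambdaVariable instead leaves binding to the analyzer, so type coercion and error reporting follow the normal resolution path. A minimal sketch of the construction that the createLambda helpers below perform:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, UnresolvedNamedLambdaVariable}

    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    // resolved and bound later, during analysis, rather than at construction time
    val fn = LambdaFunction((Column(x) + 1).expr, Seq(x))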
Expression) => Expression, - finish: Expression => Expression): Expression = { - val ArrayType(et, cn) = expr.dataType - val zeroType = zero.dataType - ArrayAggregate( - expr, - zero, - createLambda(zeroType, true, et, cn, merge), - createLambda(zeroType, true, finish)) - .bind(validateBinding) - } - - def aggregate( - expr: Expression, - zero: Expression, - merge: (Expression, Expression) => Expression): Expression = { - aggregate(expr, zero, merge, identity) - } - - def zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression) => Expression): Expression = { - val ArrayType(leftT, _) = left.dataType - val ArrayType(rightT, _) = right.dataType - ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) - } - - // Map-based helpers - - def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val MapType(kt, vt, vcn) = expr.dataType - MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) - } - - def map_zip_with( - left: Expression, - right: Expression, - f: (Expression, Expression, Expression) => Expression): Expression = { - val MapType(kt, vt1, _) = left.dataType - val MapType(_, vt2, _) = right.dataType - MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) - .bind(validateBinding) - } -} - /** * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 046c67a8a665c..1411be8007deb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -24,7 +24,102 @@ import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.expressions.HigherOrderUtils._ + + private def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + + private def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size === argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType === dataType) + assert(arg.nullable === nullable) + } + f + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) + } + + def filter(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) + } + + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, et, cn, merge), + createLambda(zeroType, true, finish)) + .bind(validateBinding) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + + def 
transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) @@ -68,6 +163,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapFilter") { + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) + } val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), @@ -145,6 +244,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ArrayExists") { + def exists(expr: Expression, f: Expression => Expression): Expression = { + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) + } + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) @@ -377,6 +481,16 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) + } + val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30), MapType(IntegerType, IntegerType, valueContainsNull = false)) val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4), @@ -459,6 +573,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("ZipWith") { + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) + } + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cc3d23d943c65..5e3d43e25dc9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3322,27 +3322,47 @@ object functions { ArrayExcept(col1.expr, col2.expr) } - private def expressionFunction(f: Column => Column): Expression => Expression = - x => f(Column(x)).expr + private def createLambda(f: Column => Column) = { + val x = UnresolvedNamedLambdaVariable(Seq("x")) + val function = f(Column(x)).expr + LambdaFunction(function, Seq(x)) + } - private def expressionFunction(f: (Column, Column) => Column) - : (Expression, Expression) => Expression = - (x, y) => 
f(Column(x), Column(y)).expr
+  private def createLambda(f: (Column, Column) => Column) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val function = f(Column(x), Column(y)).expr
+    LambdaFunction(function, Seq(x, y))
+  }
 
-  private def expressionFunction(f: (Column, Column, Column) => Column)
-    : (Expression, Expression, Expression) => Expression =
-    (x, y, z) => f(Column(x), Column(y), Column(z)).expr
+  private def createLambda(f: (Column, Column, Column) => Column) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val z = UnresolvedNamedLambdaVariable(Seq("z"))
+    val function = f(Column(x), Column(y), Column(z)).expr
+    LambdaFunction(function, Seq(x, y, z))
+  }
 
-  private def expressionFunction(f: Function[Column, Column]): Expression => Expression =
-    x => f.call(Column(x)).expr
+  private def createLambda(f: Function[Column, Column]) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val function = f.call(Column(x)).expr
+    LambdaFunction(function, Seq(x))
+  }
 
-  private def expressionFunction(f: Function2[Column, Column, Column])
-    : (Expression, Expression) => Expression =
-    (x, y) => f.call(Column(x), Column(y)).expr
+  private def createLambda(f: Function2[Column, Column, Column]) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val function = f.call(Column(x), Column(y)).expr
+    LambdaFunction(function, Seq(x, y))
+  }
 
-  private def expressionFunction(f: Function3[Column, Column, Column, Column])
-    : (Expression, Expression, Expression) => Expression =
-    (x, y, z) => f.call(Column(x), Column(y), Column(z)).expr
+  private def createLambda(f: Function3[Column, Column, Column, Column]) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val z = UnresolvedNamedLambdaVariable(Seq("z"))
+    val function = f.call(Column(x), Column(y), Column(z)).expr
+    LambdaFunction(function, Seq(x, y, z))
+  }
 
   /**
   * (Scala-specific) Returns an array of elements after applying a transformation to each element
@@ -3351,7 +3371,7 @@ object functions {
    * @group collection_funcs
    */
   def transform(column: Column, f: Column => Column): Column = withExpr {
-    HigherOrderUtils.transform(column.expr, expressionFunction(f))
+    ArrayTransform(column.expr, createLambda(f))
   }
 
   /**
@@ -3361,7 +3381,7 @@ object functions {
    * @group collection_funcs
    */
   def transform(column: Column, f: (Column, Column) => Column): Column = withExpr {
-    HigherOrderUtils.transform(column.expr, expressionFunction(f))
+    ArrayTransform(column.expr, createLambda(f))
   }
 
   /**
@@ -3370,7 +3390,7 @@ object functions {
    * @group collection_funcs
    */
   def exists(column: Column, f: Column => Column): Column = withExpr {
-    HigherOrderUtils.exists(column.expr, expressionFunction(f))
+    ArrayExists(column.expr, createLambda(f))
   }
 
   /**
@@ -3379,7 +3399,7 @@ object functions {
    * @group collection_funcs
    */
   def filter(column: Column, f: Column => Column): Column = withExpr {
-    HigherOrderUtils.filter(column.expr, expressionFunction(f))
+    ArrayFilter(column.expr, createLambda(f))
   }
 
   /**
@@ -3391,11 +3411,11 @@ object functions {
   def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column,
       finish: Column => Column): Column = withExpr {
-    HigherOrderUtils.aggregate(
+    ArrayAggregate(
       expr.expr,
       zero.expr,
-      expressionFunction(merge),
-      expressionFunction(finish)
+      createLambda(merge),
+      createLambda(finish)
     )
   }
@@ -3416,7 +3436,7 @@ 
object functions { * @group collection_funcs */ def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + ZipWith(left.expr, right.expr, createLambda(f)) } /** @@ -3426,7 +3446,7 @@ object functions { * @group collection_funcs */ def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + TransformKeys(expr.expr, createLambda(f)) } /** @@ -3436,7 +3456,7 @@ object functions { * @group collection_funcs */ def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + TransformValues(expr.expr, createLambda(f)) } /** @@ -3445,7 +3465,7 @@ object functions { * @group collection_funcs */ def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr { - HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + MapFilter(expr.expr, createLambda(f)) } /** @@ -3455,7 +3475,7 @@ object functions { */ def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = withExpr { - HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + MapZipWith(left.expr, right.expr, createLambda(f)) } /** @@ -3465,7 +3485,7 @@ object functions { * @group collection_funcs */ def transform(column: Column, f: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.transform(column.expr, expressionFunction(f)) + ArrayTransform(column.expr, createLambda(f)) } /** @@ -3475,7 +3495,7 @@ object functions { * @group collection_funcs */ def transform(column: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.transform(column.expr, expressionFunction(f)) + ArrayTransform(column.expr, createLambda(f)) } /** @@ -3484,7 +3504,7 @@ object functions { * @group collection_funcs */ def exists(column: Column, f: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.exists(column.expr, expressionFunction(f)) + ArrayExists(column.expr, createLambda(f)) } /** @@ -3493,7 +3513,7 @@ object functions { * @group collection_funcs */ def filter(column: Column, f: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.filter(column.expr, expressionFunction(f)) + ArrayFilter(column.expr, createLambda(f)) } /** @@ -3505,11 +3525,11 @@ object functions { */ def aggregate(expr: Column, zero: Column, merge: Function2[Column, Column, Column], finish: Function[Column, Column]): Column = withExpr { - HigherOrderUtils.aggregate( + ArrayAggregate( expr.expr, zero.expr, - expressionFunction(merge), - expressionFunction(finish) + createLambda(merge), + createLambda(finish) ) } @@ -3531,7 +3551,7 @@ object functions { */ def zip_with(left: Column, right: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.zip_with(left.expr, right.expr, expressionFunction(f)) + ZipWith(left.expr, right.expr, createLambda(f)) } /** @@ -3541,7 +3561,7 @@ object functions { * @group collection_funcs */ def transform_keys(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.transformKeys(expr.expr, expressionFunction(f)) + TransformKeys(expr.expr, createLambda(f)) } /** @@ -3551,7 +3571,7 @@ object functions { * @group collection_funcs */ def transform_values(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { - 
HigherOrderUtils.transformValues(expr.expr, expressionFunction(f)) + TransformValues(expr.expr, createLambda(f)) } /** @@ -3560,7 +3580,7 @@ object functions { * @group collection_funcs */ def map_filter(expr: Column, f: Function2[Column, Column, Column]): Column = withExpr { - HigherOrderUtils.mapFilter(expr.expr, expressionFunction(f)) + MapFilter(expr.expr, createLambda(f)) } /** @@ -3570,7 +3590,7 @@ object functions { */ def map_zip_with(left: Column, right: Column, f: Function3[Column, Column, Column, Column]): Column = withExpr { - HigherOrderUtils.map_zip_with(left.expr, right.expr, expressionFunction(f)) + MapZipWith(left.expr, right.expr, createLambda(f)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 69a1e2130d990..068382d0f0dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1930,13 +1930,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, 9, 11, 10, 6)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), x => x + 1)), + checkAnswer(df.select(transform(col("i"), x => x + 1)), Seq( Row(Seq(2, 10, 9, 8)), Row(Seq(6, 9, 10, 8, 3)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), Seq( Row(Seq(1, 10, 10, 10)), Row(Seq(5, 9, 11, 10, 6)), @@ -1972,13 +1972,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(5, null, 10, 12, 11, 7)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), x => x + 1)), + checkAnswer(df.select(transform(col("i"), x => x + 1)), Seq( Row(Seq(2, 10, 9, null, 8)), Row(Seq(6, null, 9, 10, 8, 3)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("i"), (x, i) => x + i)), + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), Seq( Row(Seq(1, 10, 10, null, 11)), Row(Seq(5, null, 10, 12, 11, 7)), @@ -2014,13 +2014,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b0", null, "c2", null)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("s"), x => concat(x, x))), + checkAnswer(df.select(transform(col("s"), x => concat(x, x))), Seq( Row(Seq("cc", "aa", "bb")), Row(Seq("bb", null, "cc", null)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("s"), (x, i) => concat(x, i))), + checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))), Seq( Row(Seq("c0", "a1", "b2")), Row(Seq("b0", null, "c2", null)), @@ -2070,13 +2070,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq("b", null, "c", null, null))), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("arg"), arg => arg)), + checkAnswer(df.select(transform(col("arg"), arg => arg)), Seq( Row(Seq("c", "a", "b")), Row(Seq("b", null, "c", null)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("arg"), _ => df("arg"))), + checkAnswer(df.select(transform(col("arg"), _ => col("arg"))), Seq( Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), Row(Seq( @@ -2086,7 +2086,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq("b", null, "c", null))), Row(Seq.empty), Row(null))) - checkAnswer(df.select(transform(df("arg"), x => concat(df("arg"), array(x)))), + 
checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))), Seq( Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), Row(Seq( @@ -2143,8 +2143,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(1 -> 10), Map(3 -> -3)))) checkAnswer(dfInts.select( - map_filter(dfInts("m"), (k, v) => k * 10 === v), - map_filter(dfInts("m"), (k, v) => k === (v * -1))), + map_filter(col("m"), (k, v) => k * 10 === v), + map_filter(col("m"), (k, v) => k === (v * -1))), Seq( Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), @@ -2161,8 +2161,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map(), Map(2 -> Seq(-2, -2))))) checkAnswer(dfComplex.select( - map_filter(dfComplex("m"), (k, v) => k === element_at(v, 1)), - map_filter(dfComplex("m"), (k, v) => k === size(v))), + map_filter(col("m"), (k, v) => k === element_at(v, 1)), + map_filter(col("m"), (k, v) => k === size(v))), Seq( Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), Row(Map(), Map(2 -> Seq(-2, -2))))) @@ -2189,10 +2189,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) - val ex3a = intercept[MatchError] { - df.select(map_filter(df("i"), (k, v) => k > v)) + val ex3a = intercept[AnalysisException] { + df.select(map_filter(col("i"), (k, v) => k > v)) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires map type")) val ex4 = intercept[AnalysisException] { df.selectExpr("map_filter(a, (k, v) -> k > v)") @@ -2215,7 +2215,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), Seq( Row(Seq(8)), Row(Seq(8, 2)), @@ -2245,7 +2245,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(8, 2)), Row(Seq.empty), Row(null))) - checkAnswer(df.select(filter(df("i"), _ % 2 === 0)), + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), Seq( Row(Seq(8)), Row(Seq(8, 2)), @@ -2275,7 +2275,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq("b", "c")), Row(Seq.empty), Row(null))) - checkAnswer(df.select(filter(df("s"), x => x.isNotNull)), + checkAnswer(df.select(filter(col("s"), x => x.isNotNull)), Seq( Row(Seq("c", "a", "b")), Row(Seq("b", "c")), @@ -2308,10 +2308,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2a = intercept[MatchError] { - df.select(filter(df("i"), x => x)) + val ex2a = intercept[AnalysisException] { + df.select(filter(col("i"), x => x)) } - assert(ex2a.getMessage.contains("IntegerType")) + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex3 = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") @@ -2319,7 +2319,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) val ex3a = intercept[AnalysisException] { - df.select(filter(df("s"), x => x)) + df.select(filter(col("s"), x => x)) } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) 
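For orientation, the hunks above and below exercise the new Column-based overloads end to end. A minimal usage sketch of the API this patch exposes, assuming a hypothetical DataFrame df with an integer array column i:

    import org.apache.spark.sql.functions._

    // The SQL lambda syntax and the new Scala overloads are expected to agree:
    df.selectExpr("filter(i, x -> x % 2 == 0)")
    df.select(filter(col("i"), _ % 2 === 0))

    // aggregate folds the array into the zero value, with an optional finish step:
    df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10))

Each Scala closure is converted by createLambda into a LambdaFunction over UnresolvedNamedLambdaVariables, so the analyzer resolves it exactly as it would the SQL form.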
@@ -2344,7 +2344,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(false), Row(false), Row(null))) - checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), Seq( Row(true), Row(false), @@ -2376,7 +2376,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null), Row(false), Row(null))) - checkAnswer(df.select(exists(df("i"), _ % 2 === 0)), + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), Seq( Row(true), Row(false), @@ -2406,7 +2406,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(false), Row(null))) - checkAnswer(df.select(exists(df("s"), x => x.isNull)), + checkAnswer(df.select(exists(col("s"), x => x.isNull)), Seq( Row(false), Row(true), @@ -2439,10 +2439,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2a = intercept[MatchError] { - df.select(exists(df("i"), x => x)) + val ex2a = intercept[AnalysisException] { + df.select(exists(col("i"), x => x)) } - assert(ex2a.getMessage.contains("IntegerType")) + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) @@ -2481,13 +2481,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(310), Row(0), Row(null))) - checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), Seq( Row(25), Row(31), Row(0), Row(null))) - checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x, _ * 10)), + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), Seq( Row(250), Row(310), @@ -2524,7 +2524,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(0), Row(0), Row(null))) - checkAnswer(df.select(aggregate(df("i"), lit(0), (acc, x) => acc + x)), + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), Seq( Row(25), Row(null), Row(0), Row(null))) checkAnswer( df.select( - aggregate(df("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), + aggregate(col("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), Seq( Row(250), Row(0), @@ -2569,7 +2569,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(""), Row("c"), Row(null))) - checkAnswer(df.select(aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x))), + checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), Seq( Row("acab"), Row(null), @@ -2577,7 +2577,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null))) checkAnswer( df.select( - aggregate(df("ss"), df("s"), (acc, x) => concat(acc, x), acc => coalesce(acc, lit("")))), + aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x), + acc => coalesce(acc, lit("")))), Seq( Row("acab"), Row(""), @@ -2615,10 +2616,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3a = intercept[MatchError] { - df.select(aggregate(df("i"), lit(0), (acc, x) => x)) + val ex3a = intercept[AnalysisException] { + df.select(aggregate(col("i"), lit(0), (acc, x) => x)) } -
assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } @@ -2626,7 +2627,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) val ex4a = intercept[AnalysisException] { - df.select(aggregate(df("s"), lit(0), (acc, x) => x)) + df.select(aggregate(col("s"), lit(0), (acc, x) => x)) } assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) @@ -2674,7 +2675,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map("a" -> Row("d", null))), Row(null))) - checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => struct(v1, v2))), + checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), (k, v1, v2) => struct(v1, v2))), Seq( Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), @@ -2698,30 +2699,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) - val ex2a = intercept[NoSuchElementException] { - df.select(map_zip_with(df("mis"), df("mmi"), (x, y, z) => concat(x, y, z))) + val ex2a = intercept[AnalysisException] { + df.select(map_zip_with(col("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) } - assert(ex2a.getMessage.contains("None.get")) + assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") } assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) - val ex3a = intercept[MatchError] { - df.select(map_zip_with(df("i"), df("mis"), (x, y, z) => concat(x, y, z))) + val ex3a = intercept[AnalysisException] { + df.select(map_zip_with(col("i"), col("mis"), (x, y, z) => concat(x, y, z))) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("type mismatch: argument 1 requires map type")) val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) - val ex4a = intercept[MatchError] { - df.select(map_zip_with(df("mis"), df("i"), (x, y, z) => concat(x, y, z))) + val ex4a = intercept[AnalysisException] { + df.select(map_zip_with(col("mis"), col("i"), (x, y, z) => concat(x, y, z))) } - assert(ex4a.getMessage.contains("IntegerType")) + assert(ex4a.getMessage.contains("type mismatch: argument 2 requires map type")) val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } @@ -2751,7 +2753,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) - checkAnswer(dfExample1.select(transform_keys(dfExample1("i"), (k, v) => k + v)), + checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) checkAnswer(dfExample2.selectExpr("transform_keys(j, " + @@ -2760,7 +2762,7 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSQLContext { checkAnswer(dfExample2.select( transform_keys( - dfExample2("j"), + col("j"), (k, v) => element_at( map_from_arrays( array(lit(1), lit(2), lit(3)), @@ -2775,33 +2777,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) - checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => (v * 2).cast("bigint") + k)), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) - checkAnswer(dfExample2.select(transform_keys(dfExample2("j"), (k, v) => k + v)), + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), Seq(Row(Map(true -> true, true -> false)))) - checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), (k, v) => k % 2 === 0 || v)), + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), Seq(Row(Map(true -> true, true -> false)))) checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample3.select(transform_keys(dfExample3("x"), + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => when(v, k * 2).otherwise(k * 3))), Seq(Row(Map(50 -> true, 78 -> false)))) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) - checkAnswer(dfExample4.select(transform_keys(dfExample4("y"), + checkAnswer(dfExample4.select(transform_keys(col("y"), (k, v) => array_contains(k, lit(3)) && v)), Seq(Row(Map(false -> false)))) } @@ -2842,7 +2844,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("Cannot use null as map key")) val ex3a = intercept[Exception] { - dfExample1.select(transform_keys(dfExample1("i"), (k, v) => v)).show() + dfExample1.select(transform_keys(col("i"), (k, v) => v)).show() } assert(ex3a.getMessage.contains("Cannot use null as map key")) @@ -2911,26 +2913,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), Seq(Row(Map(1 -> 3)))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k + v)), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k + v)), Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) checkAnswer(dfExample2.select( - transform_values(dfExample2("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), + transform_values(col("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), Seq(Row(Map(false -> "false", true -> "def")))) - checkAnswer(dfExample2.select(transform_values(dfExample2("x"), + checkAnswer(dfExample2.select(transform_values(col("x"), (k, v) => (!k) && v === "abc")), Seq(Row(Map(false -> true, true -> false)))) - checkAnswer(dfExample3.select(transform_values(dfExample3("y"), (k, v) => v * v)), + checkAnswer(dfExample3.select(transform_values(col("y"), (k, v) => v * v)), Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) checkAnswer(dfExample3.select( - transform_values(dfExample3("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), + 
transform_values(col("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) checkAnswer( - dfExample3.select(transform_values(dfExample3("y"), (k, v) => concat(k, v.cast("string")))), + dfExample3.select(transform_values(col("y"), (k, v) => concat(k, v.cast("string")))), Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) val testMap = map_from_arrays( @@ -2939,16 +2941,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - dfExample4.select(transform_values(dfExample4("z"), + dfExample4.select(transform_values(col("z"), (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))), Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) checkAnswer( - dfExample4.select(transform_values(dfExample4("z"), (k, v) => k - v)), + dfExample4.select(transform_values(col("z"), (k, v) => k - v)), Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) checkAnswer( - dfExample5.select(transform_values(dfExample5("c"), (k, v) => k + size(v))), + dfExample5.select(transform_values(col("c"), (k, v) => k + size(v))), Seq(Row(Map(1 -> 3)))) } @@ -2994,26 +2996,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), Seq(Row(Map.empty[BigInt, BigInt]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(null).cast("int"))), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => k)), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k)), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v)), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v)), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(0))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(0))), Seq(Row(Map.empty[Integer, Integer]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit("value"))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit("value"))), Seq(Row(Map.empty[Integer, String]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => lit(true))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(true))), Seq(Row(Map.empty[Integer, Boolean]))) - checkAnswer(dfExample1.select(transform_values(dfExample1("i"), (k, v) => v.cast("bigint"))), + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v.cast("bigint"))), Seq(Row(Map.empty[BigInt, BigInt]))) } @@ -3040,12 +3042,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) - checkAnswer(dfExample1.select(transform_values(dfExample1("a"), + checkAnswer(dfExample1.select(transform_values(col("a"), (k, v) => lit(null).cast("int"))), Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) checkAnswer(dfExample2.select( - transform_values(dfExample2("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) + transform_values(col("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) ), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) } @@ -3087,10 +3089,11 @@ class DataFrameFunctionsSuite 
extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains( "data type mismatch: argument 1 requires map type")) - val ex3a = intercept[MatchError] { - dfExample3.select(transform_values(dfExample3("x"), (k, v) => k + 1)) + val ex3a = intercept[AnalysisException] { + dfExample3.select(transform_values(col("x"), (k, v) => k + 1)) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() @@ -3145,7 +3148,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { expectedValue1 ) checkAnswer( - df.select(zip_with(df("val1"), df("val2"), (x, y) => struct(y, x))), + df.select(zip_with(col("val1"), col("val2"), (x, y) => struct(y, x))), expectedValue1 ) } @@ -3169,10 +3172,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("zip_with(i, a2, (acc, x) -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3a = intercept[MatchError] { + val ex3a = intercept[AnalysisException] { df.select(zip_with(df("i"), df("a2"), (acc, x) => x)) } - assert(ex3a.getMessage.contains("IntegerType")) + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex4 = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a, (acc, x) -> x)") } From a8c7ecd27b8d0fcabfd86571eeba801bb5c7e62a Mon Sep 17 00:00:00 2001 From: Nik Date: Tue, 6 Aug 2019 19:47:54 -0400 Subject: [PATCH 14/38] Add forall to org.apache.spark.sql.functions --- .../org/apache/spark/sql/functions.scala | 18 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 41 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9a9a934eaeb4f..3a74ce66d999c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3415,6 +3415,15 @@ object functions { ArrayExists(column.expr, createLambda(f)) } + /** + * (Scala-specific) Returns whether a predicate holds for every element in the array. + * + * @group collection_funcs + */ + def forall(column: Column, f: Column => Column): Column = withExpr { + ArrayForAll(column.expr, createLambda(f)) + } + /** * (Scala-specific) Returns an array of elements for which a predicate holds in a given array. * @@ -3529,6 +3538,15 @@ object functions { ArrayExists(column.expr, createLambda(f)) } + /** + * (Java-specific) Returns whether a predicate holds for every element in the array. + * + * @group collection_funcs + */ + def forall(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { + ArrayForAll(column.expr, createLambda(f)) + } + /** * (Java-specific) Returns an array of elements for which a predicate holds in a given array. 
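As a usage sketch for the new forall (the DataFrame df and its array column i are hypothetical), the predicate must produce a boolean column; under SQL three-valued logic a null element yields a null result unless the predicate absorbs it:

    // Scala API and the equivalent SQL expression:
    df.select(forall(col("i"), x => x % 2 === 0))
    df.selectExpr("forall(i, x -> x % 2 == 0)")

    // Nulls can be handled explicitly inside the predicate:
    df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull))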
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 502ed6f1ea98b..b2da6d03a6764 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2476,6 +2476,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2502,6 +2508,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull)), + Seq( + Row(false), + Row(true), + Row(true), + Row(true), + Row(null))) checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), Seq( Row(false), @@ -2509,6 +2522,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), + Seq( + Row(false), + Row(null), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2533,6 +2553,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("s"), _.isNull)), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2560,15 +2586,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[AnalysisException] { + df.select(forall(col("i"), x => x)) + } + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("forall(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(forall(col("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("forall(a, x -> x)") } assert(ex4.getMessage.contains("cannot resolve '`a`'")) + + val ex4a = intercept[AnalysisException] { + df.select(forall(col("a"), x => x)) + } + assert(ex4a.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { From 96fb0ad110fc0e0cc3e8f8097fd370143999a7c2 Mon Sep 17 00:00:00 2001 From: Nik Date: Fri, 9 Aug 2019 21:48:47 -0400 Subject: [PATCH 15/38] Add "@since 3.0.0" to new functions --- .../org/apache/spark/sql/functions.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a74ce66d999c..9c2a57364b738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3391,6 +3391,7 @@ object functions { * in the input array. 
* * @group collection_funcs + * @since 3.0.0 */ def transform(column: Column, f: Column => Column): Column = withExpr { ArrayTransform(column.expr, createLambda(f)) @@ -3401,6 +3402,7 @@ object functions { * in the input array. * * @group collection_funcs + * @since 3.0.0 */ def transform(column: Column, f: (Column, Column) => Column): Column = withExpr { ArrayTransform(column.expr, createLambda(f)) @@ -3410,6 +3412,7 @@ object functions { * (Scala-specific) Returns whether a predicate holds for one or more elements in the array. * * @group collection_funcs + * @since 3.0.0 */ def exists(column: Column, f: Column => Column): Column = withExpr { ArrayExists(column.expr, createLambda(f)) @@ -3419,6 +3422,7 @@ object functions { * (Scala-specific) Returns whether a predicate holds for every element in the array. * * @group collection_funcs + * @since 3.0.0 */ def forall(column: Column, f: Column => Column): Column = withExpr { ArrayForAll(column.expr, createLambda(f)) @@ -3428,6 +3432,7 @@ object functions { * (Scala-specific) Returns an array of elements for which a predicate holds in a given array. * * @group collection_funcs + * @since 3.0.0 */ def filter(column: Column, f: Column => Column): Column = withExpr { ArrayFilter(column.expr, createLambda(f)) @@ -3439,6 +3444,7 @@ object functions { * by applying a finish function. * * @group collection_funcs + * @since 3.0.0 */ def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column, finish: Column => Column): Column = withExpr { @@ -3455,6 +3461,7 @@ object functions { * and reduces this to a single state. * * @group collection_funcs + * @since 3.0.0 */ def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column = aggregate(expr, zero, merge, c => c) @@ -3465,6 +3472,7 @@ object functions { * array, before applying the function. * * @group collection_funcs + * @since 3.0.0 */ def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr { ZipWith(left.expr, right.expr, createLambda(f)) @@ -3475,6 +3483,7 @@ object functions { * a map with the results of those applications as the new keys for the pairs. * * @group collection_funcs + * @since 3.0.0 */ def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr { TransformKeys(expr.expr, createLambda(f)) @@ -3485,6 +3494,7 @@ object functions { * a map with the results of those applications as the new values for the pairs. * * @group collection_funcs + * @since 3.0.0 */ def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr { TransformValues(expr.expr, createLambda(f)) @@ -3494,6 +3504,7 @@ object functions { * (Scala-specific) Returns a map whose key-value pairs satisfy a predicate. * * @group collection_funcs + * @since 3.0.0 */ def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr { MapFilter(expr.expr, createLambda(f)) @@ -3503,6 +3514,7 @@ object functions { * (Scala-specific) Merge two given maps, key-wise into a single map using a function. * * @group collection_funcs + * @since 3.0.0 */ def map_zip_with(left: Column, right: Column, f: (Column, Column, Column) => Column): Column = withExpr { @@ -3514,6 +3526,7 @@ object functions { * in the input array. * * @group collection_funcs + * @since 3.0.0 */ def transform(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayTransform(column.expr, createLambda(f)) @@ -3524,6 +3537,7 @@ object functions { * in the input array. 
* * @group collection_funcs + * @since 3.0.0 */ def transform(column: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { ArrayTransform(column.expr, createLambda(f)) @@ -3533,6 +3547,7 @@ object functions { * (Java-specific) Returns whether a predicate holds for one or more elements in the array. * * @group collection_funcs + * @since 3.0.0 */ def exists(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayExists(column.expr, createLambda(f)) @@ -3542,6 +3557,7 @@ object functions { * (Java-specific) Returns whether a predicate holds for every element in the array. * * @group collection_funcs + * @since 3.0.0 */ def forall(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayForAll(column.expr, createLambda(f)) @@ -3551,6 +3567,7 @@ object functions { * (Java-specific) Returns an array of elements for which a predicate holds in a given array. * * @group collection_funcs + * @since 3.0.0 */ def filter(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { ArrayFilter(column.expr, createLambda(f)) @@ -3562,6 +3579,7 @@ object functions { * by applying a finish function. * * @group collection_funcs + * @since 3.0.0 */ def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column], finish: JavaFunction[Column, Column]): Column = withExpr { @@ -3578,6 +3596,7 @@ object functions { * and reduces this to a single state. * * @group collection_funcs + * @since 3.0.0 */ def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column]): Column = aggregate( @@ -3589,6 +3608,7 @@ object functions { * array, before applying the function. * * @group collection_funcs + * @since 3.0.0 */ def zip_with(left: Column, right: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { @@ -3600,6 +3620,7 @@ object functions { * a map with the results of those applications as the new keys for the pairs. * * @group collection_funcs + * @since 3.0.0 */ def transform_keys(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { TransformKeys(expr.expr, createLambda(f)) @@ -3610,6 +3631,7 @@ object functions { * a map with the results of those applications as the new values for the pairs. * * @group collection_funcs + * @since 3.0.0 */ def transform_values(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { TransformValues(expr.expr, createLambda(f)) @@ -3619,6 +3641,7 @@ object functions { * (Java-specific) Returns a map whose key-value pairs satisfy a predicate. * * @group collection_funcs + * @since 3.0.0 */ def map_filter(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { MapFilter(expr.expr, createLambda(f)) @@ -3628,6 +3651,7 @@ object functions { * (Java-specific) Merge two given maps, key-wise into a single map using a function. 
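To make the Java-specific overloads concrete: from Java they are intended to be called with lambda syntax, while Scala callers (as in the test suite) can use anonymous classes. A minimal sketch, assuming a hypothetical DataFrame df with a map column m; JavaFunction2 here is org.apache.spark.api.java.function.Function2:

    import org.apache.spark.api.java.function.{Function2 => JavaFunction2}

    df.select(transform_values(col("m"), new JavaFunction2[Column, Column, Column] {
      def call(k: Column, v: Column): Column = k + v
    }))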
* * @group collection_funcs + * @since 3.0.0 */ def map_zip_with(left: Column, right: Column, f: JavaFunction3[Column, Column, Column, Column]): Column = withExpr { From 5fa3e71dffb44def349334713989c10825268a39 Mon Sep 17 00:00:00 2001 From: Nik Date: Fri, 9 Aug 2019 22:25:42 -0400 Subject: [PATCH 16/38] Add tests for Java transform function --- .../spark/sql/DataFrameFunctionsSuite.scala | 271 ++++++++++-------- 1 file changed, 150 insertions(+), 121 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b2da6d03a6764..a71a8e4f52436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.TimeZone import scala.util.Random +import org.apache.spark.api.java.function.{Function => JavaFunction, Function2 => JavaFunction2, Function3 => JavaFunction3} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -37,6 +38,9 @@ import org.apache.spark.sql.types._ * Test suite for functions in [[org.apache.spark.sql.functions]]. */ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + type JFunc = JavaFunction[Column, Column] + type JFunc2 = JavaFunction2[Column, Column, Column] + type JFunc3 = JavaFunction3[Column, Column, Column, Column] import testImplicits._ test("array with column name") { @@ -1917,31 +1921,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // transform(i, x -> x + 1) + val resA = Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null)) + + // transform(i, (x, i) -> x + i) + val resB = Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null)) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), - Seq( - Row(Seq(2, 10, 9, 8)), - Row(Seq(6, 9, 10, 8, 3)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), - Seq( - Row(Seq(1, 10, 10, 10)), - Row(Seq(5, 9, 11, 10, 6)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("i"), x => x + 1)), - Seq( - Row(Seq(2, 10, 9, 8)), - Row(Seq(6, 9, 10, 8, 3)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), - Seq( - Row(Seq(1, 10, 10, 10)), - Row(Seq(5, 9, 11, 10, 6)), - Row(Seq.empty), - Row(null))) + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), resA) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), resB) + + checkAnswer(df.select(transform(col("i"), x => x + 1)), resA) + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), resB) + + checkAnswer(df.select(transform(col("i"), new JFunc { + def call(x: Column) = x + 1 + })), resA) + checkAnswer(df.select(transform(col("i"), new JFunc2 { + def call(x: Column, i: Column) = x + i + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -1959,31 +1965,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // transform(i, x -> x + 1) + val resA = Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null)) + + // transform(i, (x, i) -> x + i) + val resB = Seq( + 
Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null)) + def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), - Seq( - Row(Seq(2, 10, 9, null, 8)), - Row(Seq(6, null, 9, 10, 8, 3)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), - Seq( - Row(Seq(1, 10, 10, null, 11)), - Row(Seq(5, null, 10, 12, 11, 7)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("i"), x => x + 1)), - Seq( - Row(Seq(2, 10, 9, null, 8)), - Row(Seq(6, null, 9, 10, 8, 3)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), - Seq( - Row(Seq(1, 10, 10, null, 11)), - Row(Seq(5, null, 10, 12, 11, 7)), - Row(Seq.empty), - Row(null))) + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), resA) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), resB) + + checkAnswer(df.select(transform(col("i"), x => x + 1)), resA) + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), resB) + + checkAnswer(df.select(transform(col("i"), new JFunc { + def call(x: Column) = x + 1 + })), resA) + checkAnswer(df.select(transform(col("i"), new JFunc2 { + def call(x: Column, i: Column) = x + i + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -2001,31 +2009,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") + // transform(s, x -> concat(x, x)) + val resA = Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null)) + + // transform(s, (x, i) -> concat(x, i)) + val resB = Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null)) + def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), - Seq( - Row(Seq("cc", "aa", "bb")), - Row(Seq("bb", null, "cc", null)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), - Seq( - Row(Seq("c0", "a1", "b2")), - Row(Seq("b0", null, "c2", null)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("s"), x => concat(x, x))), - Seq( - Row(Seq("cc", "aa", "bb")), - Row(Seq("bb", null, "cc", null)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))), - Seq( - Row(Seq("c0", "a1", "b2")), - Row(Seq("b0", null, "c2", null)), - Row(Seq.empty), - Row(null))) + checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), resA) + checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), resB) + + checkAnswer(df.select(transform(col("s"), x => concat(x, x))), resA) + checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))), resB) + + checkAnswer(df.select(transform(col("s"), new JFunc { + def call(x: Column) = concat(x, x) + })), resA) + checkAnswer(df.select(transform(col("s"), new JFunc2 { + def call(x: Column, i: Column) = concat(x, i) + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -2043,59 +2053,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("arg") - def testSpecialCases(): Unit = { - checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), + // transform(arg, arg -> arg) + val resA = Seq( Row(Seq("c", "a", "b")), Row(Seq("b", null, "c", null)), Row(Seq.empty), - Row(null))) - checkAnswer(df.selectExpr("transform(arg, arg)"), - Seq( - Row(Seq(Seq("c", "a", "b"), 
Seq("c", "a", "b"), Seq("c", "a", "b"))), - Row(Seq( - Seq("b", null, "c", null), - Seq("b", null, "c", null), - Seq("b", null, "c", null), - Seq("b", null, "c", null))), - Row(Seq.empty), - Row(null))) - checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), - Seq( - Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), - Row(Seq( - Seq("b", null, "c", null, "b"), - Seq("b", null, "c", null, null), - Seq("b", null, "c", null, "c"), - Seq("b", null, "c", null, null))), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("arg"), arg => arg)), - Seq( - Row(Seq("c", "a", "b")), - Row(Seq("b", null, "c", null)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("arg"), _ => col("arg"))), - Seq( - Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), - Row(Seq( - Seq("b", null, "c", null), - Seq("b", null, "c", null), - Seq("b", null, "c", null), - Seq("b", null, "c", null))), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))), - Seq( - Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), - Row(Seq( - Seq("b", null, "c", null, "b"), - Seq("b", null, "c", null, null), - Seq("b", null, "c", null, "c"), - Seq("b", null, "c", null, null))), - Row(Seq.empty), - Row(null))) + Row(null)) + + // transform(arg, arg) + val resB = Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null)) + + // transform(arg, x -> concat(arg, array(x))) + val resC = Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null)) + + def testSpecialCases(): Unit = { + checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), resA) + checkAnswer(df.selectExpr("transform(arg, arg)"), resB) + checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), resC) + + checkAnswer(df.select(transform(col("arg"), arg => arg)), resA) + checkAnswer(df.select(transform(col("arg"), _ => col("arg"))), resB) + checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))), resC) + + checkAnswer(df.select(transform(col("arg"), new JFunc { + def call(arg: Column) = arg + })), resA) + checkAnswer(df.select(transform(col("arg"), new JFunc { + def call(arg: Column) = col("arg") + })), resB) + checkAnswer(df.select(transform(col("arg"), new JFunc { + def call(x: Column) = concat(col("arg"), array(x)) + })), resC) } // Test with local relation, the Project will be evaluated without codegen @@ -2123,10 +2128,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[AnalysisException] { + df.select(transform(col("i"), x => x)) + } + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex2b = intercept[AnalysisException] { + df.select(transform(col("i"), new JFunc { + def call(x: Column) = x + })) + } + assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("transform(a, x -> x)") } 
assert(ex3.getMessage.contains("cannot resolve '`a`'")) + + val ex3a = intercept[AnalysisException] { + df.select(transform(col("a"), x => x)) + } + assert(ex3a.getMessage.contains("cannot resolve '`a`'")) + + val ex3b = intercept[AnalysisException] { + df.select(transform(col("a"), new JFunc { + def call(x: Column) = x + })) + } + assert(ex3b.getMessage.contains("cannot resolve '`a`'")) } test("map_filter") { From 0bfa483fbc04edc7b8adad8d0619e08f1287539d Mon Sep 17 00:00:00 2001 From: Nik Date: Fri, 9 Aug 2019 22:41:51 -0400 Subject: [PATCH 17/38] Add tests for Java map_filter function --- .../spark/sql/DataFrameFunctionsSuite.scala | 64 +++++++++++++++---- 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a71a8e4f52436..c6b5eb9fbc006 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2164,37 +2164,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Map(1 -> -1, 2 -> -2, 3 -> -3), Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m") + // map_filter(m, (k, v) -> k * 10 = v), map_filter(m, (k, v) -> k = -v + val resA = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3))) + checkAnswer(dfInts.selectExpr( "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"), - Seq( - Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), - Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), - Row(Map(1 -> 10), Map(3 -> -3)))) + resA) checkAnswer(dfInts.select( map_filter(col("m"), (k, v) => k * 10 === v), map_filter(col("m"), (k, v) => k === (v * -1))), - Seq( - Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), - Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), - Row(Map(1 -> 10), Map(3 -> -3)))) + resA) + + checkAnswer(dfInts.select( + map_filter(col("m"), new JFunc2 { + def call(k: Column, v: Column): Column = k * 10 === v + }), + map_filter(col("m"), new JFunc2 { + def call(k: Column, v: Column): Column = k === (v * -1) + })), resA) val dfComplex = Seq( Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") + // map_filter(m, (k, v) -> k = v[0]), map_filter(m, (k, v) -> k = size(v)) + val resB = Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2)))) + checkAnswer(dfComplex.selectExpr( "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"), - Seq( - Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), - Row(Map(), Map(2 -> Seq(-2, -2))))) + resB) checkAnswer(dfComplex.select( map_filter(col("m"), (k, v) => k === element_at(v, 1)), map_filter(col("m"), (k, v) => k === size(v))), - Seq( - Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), - Row(Map(), Map(2 -> Seq(-2, -2))))) + resB) + + checkAnswer(dfComplex.select( + map_filter(col("m"), new JFunc2 { + def call(k: Column, v: Column): Column = k === element_at(v, 1) + }), + map_filter(col("m"), new JFunc2 { + def call(k: Column, v: Column): Column = k === size(v) + })), resB) // Invalid use cases val df = Seq( @@ -2223,10 +2240,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } 
assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires map type")) + val ex3b = intercept[AnalysisException] { + df.select(map_filter(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = k > v + })) + } + assert(ex3b.getMessage.contains("data type mismatch: argument 1 requires map type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_filter(a, (k, v) -> k > v)") } assert(ex4.getMessage.contains("cannot resolve '`a`'")) + + val ex4a = intercept[AnalysisException] { + df.select(map_filter(col("a"), (k, v) => k > v)) + } + assert(ex4a.getMessage.contains("cannot resolve '`a`'")) + + val ex4b = intercept[AnalysisException] { + df.select(map_filter(col("a"), new JFunc2 { + def call(k: Column, v: Column): Column = k > v + })) + } + assert(ex4b.getMessage.contains("cannot resolve '`a`'")) } test("filter function - array for primitive type not containing null") { From 815e9f66bb1496874459aabc2e690a136b030c9b Mon Sep 17 00:00:00 2001 From: Nik Date: Fri, 9 Aug 2019 22:51:05 -0400 Subject: [PATCH 18/38] Add tests for Java filter function --- .../spark/sql/DataFrameFunctionsSuite.scala | 86 +++++++++++-------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c6b5eb9fbc006..40a3184988d3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2273,19 +2273,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // filter(i, x -> x % 2 == 0) + val res = Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null)) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), - Seq( - Row(Seq(8)), - Row(Seq(8, 2)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), - Seq( - Row(Seq(8)), - Row(Seq(8, 2)), - Row(Seq.empty), - Row(null))) + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), res) + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), res) + checkAnswer(df.select(filter(col("i"), new JFunc { + def call(x: Column): Column = x % 2 === 0 + })), res) } // Test with local relation, the Project will be evaluated without codegen @@ -2303,19 +2303,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // filter(i, x -> x % 2 == 0) + val res = Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null)) + def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), - Seq( - Row(Seq(8)), - Row(Seq(8, 2)), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), - Seq( - Row(Seq(8)), - Row(Seq(8, 2)), - Row(Seq.empty), - Row(null))) + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), res) + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), res) + checkAnswer(df.select(filter(col("i"), new JFunc { + def call(x: Column): Column = x % 2 === 0 + })), res) } // Test with local relation, the Project will be evaluated without codegen @@ -2333,19 +2333,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") + // filter(s, x -> x is not null) + val res = Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null)) + def 
testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), - Seq( - Row(Seq("c", "a", "b")), - Row(Seq("b", "c")), - Row(Seq.empty), - Row(null))) - checkAnswer(df.select(filter(col("s"), x => x.isNotNull)), - Seq( - Row(Seq("c", "a", "b")), - Row(Seq("b", "c")), - Row(Seq.empty), - Row(null))) + checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), res) + checkAnswer(df.select(filter(col("s"), x => x.isNotNull)), res) + checkAnswer(df.select(filter(col("s"), new JFunc { + def call(x: Column) = x.isNotNull + })), res) } // Test with local relation, the Project will be evaluated without codegen @@ -2378,6 +2378,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2b = intercept[AnalysisException] { + df.select(filter(col("i"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") } @@ -2388,6 +2395,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3b = intercept[AnalysisException] { + df.select(filter(col("s"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex3b.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("filter(a, x -> x)") } From 47b100b76ebcad1fa35862cb3b841f4d0c2a4431 Mon Sep 17 00:00:00 2001 From: Nik Date: Fri, 9 Aug 2019 22:57:48 -0400 Subject: [PATCH 19/38] Add tests for Java exists function --- .../spark/sql/DataFrameFunctionsSuite.scala | 89 +++++++++++-------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 40a3184988d3f..f7d5472866320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2416,19 +2416,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // exists(i, x -> x % 2 == 0) + val res = Seq( + Row(true), + Row(false), + Row(false), + Row(null)) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), - Seq( - Row(true), - Row(false), - Row(false), - Row(null))) - checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), - Seq( - Row(true), - Row(false), - Row(false), - Row(null))) + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), res) + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), res) + checkAnswer(df.select(exists(col("i"), new JFunc { + def call(x: Column): Column = x % 2 === 0 + })), res) } // Test with local relation, the Project will be evaluated without codegen @@ -2447,21 +2447,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // exists(i, x -> x % 2 == 0) + val res = Seq( + Row(true), + Row(false), + Row(null), + Row(false), + Row(null)) + def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), - Seq( - Row(true), - Row(false), - Row(null), - Row(false), - Row(null))) - 
checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), - Seq( - Row(true), - Row(false), - Row(null), - Row(false), - Row(null))) + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), res) + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), res) + checkAnswer(df.select(exists(col("i"), new JFunc { + def call(x: Column): Column = x % 2 === 0 + })), res) } // Test with local relation, the Project will be evaluated without codegen @@ -2478,19 +2478,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") + // exists(s, x -> x is null) + val res = Seq( + Row(false), + Row(true), + Row(false), + Row(null)) + def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("exists(s, x -> x is null)"), - Seq( - Row(false), - Row(true), - Row(false), - Row(null))) - checkAnswer(df.select(exists(col("s"), x => x.isNull)), - Seq( - Row(false), - Row(true), - Row(false), - Row(null))) + checkAnswer(df.selectExpr("exists(s, x -> x is null)"), res) + checkAnswer(df.select(exists(col("s"), x => x.isNull)), res) + checkAnswer(df.select(exists(col("s"), new JFunc { + def call(x: Column): Column = x.isNull + })), res) } // Test with local relation, the Project will be evaluated without codegen @@ -2524,6 +2523,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2b = intercept[AnalysisException] { + df.select(exists(col("i"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } @@ -2534,6 +2540,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3b = intercept[AnalysisException] { + df.select(exists(col("s"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex3b.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("exists(a, x -> x)") } From 4baf0840e32ee7e89e9ca0a475eb05ef9661328c Mon Sep 17 00:00:00 2001 From: Nik Date: Mon, 19 Aug 2019 18:14:56 -0400 Subject: [PATCH 20/38] Add test for Java API forall --- .../spark/sql/DataFrameFunctionsSuite.scala | 122 ++++++++++-------- 1 file changed, 70 insertions(+), 52 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f7d5472866320..527798663a2a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2561,19 +2561,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // forall(i, x -> x % 2 == 0) + val resA = Seq( + Row(false), + Row(true), + Row(true), + Row(null)) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), - Seq( - Row(false), - Row(true), - Row(true), - Row(null))) - checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), - Seq( - Row(false), - Row(true), - Row(true), - Row(null))) + checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), resA) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), resA) + checkAnswer(df.select(forall(col("i"), new JFunc { + def call(x:
Column): Column = x % 2 === 0 + })), resA) } // Test with local relation, the Project will be evaluated without codegen @@ -2592,35 +2591,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") + // forall(i, x -> x % 2 == 0 or x is null) + val resA = Seq( + Row(false), + Row(true), + Row(true), + Row(true), + Row(null)) + + // forall(i, x -> x % 2 == 0) + val resB = Seq( + Row(false), + Row(null), + Row(true), + Row(true), + Row(null)) + def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0 or x is null)"), - Seq( - Row(false), - Row(true), - Row(true), - Row(true), - Row(null))) - checkAnswer(df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull)), - Seq( - Row(false), - Row(true), - Row(true), - Row(true), - Row(null))) - checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), - Seq( - Row(false), - Row(null), - Row(true), - Row(true), - Row(null))) - checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), - Seq( - Row(false), - Row(null), - Row(true), - Row(true), - Row(null))) + checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0 or x is null)"), resA) + checkAnswer(df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull)), resA) + checkAnswer(df.select(forall(col("i"), new JFunc { + def call(x: Column): Column = (x % 2 === 0 ) || x.isNull + })), resA) + checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), resB) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), resB) + checkAnswer(df.select(forall(col("i"), new JFunc { + def call(x: Column): Column = (x % 2 === 0 ) + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -2638,19 +2635,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") + // forall(s, x -> x is null) + val resA = Seq( + Row(false), + Row(true), + Row(true), + Row(null)) + def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("forall(s, x -> x is null)"), - Seq( - Row(false), - Row(true), - Row(true), - Row(null))) - checkAnswer(df.select(forall(col("s"), _.isNull)), - Seq( - Row(false), - Row(true), - Row(true), - Row(null))) + checkAnswer(df.selectExpr("forall(s, x -> x is null)"), resA) + checkAnswer(df.select(forall(col("s"), _.isNull)), resA) + checkAnswer(df.select(forall(col("s"), new JFunc { + def call(x: Column): Column = x.isNull + })), resA) } // Test with local relation, the Project will be evaluated without codegen @@ -2683,6 +2680,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2b = intercept[AnalysisException] { + df.select(forall(col("i"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("forall(s, x -> x)") } @@ -2693,6 +2697,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3b = intercept[AnalysisException] { + df.select(forall(col("s"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex3b.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("forall(a, x -> x)") } @@ -2702,6 +2713,13 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext { df.select(forall(col("a"), x => x)) } assert(ex4a.getMessage.contains("cannot resolve '`a`'")) + + val ex4b = intercept[AnalysisException] { + df.select(forall(col("a"), new JFunc { + def call(x: Column): Column = x + })) + } + assert(ex4b.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { From 06b4c82b687f823176ba1b2db39dd654f2b6df36 Mon Sep 17 00:00:00 2001 From: Nik Date: Mon, 19 Aug 2019 18:39:36 -0400 Subject: [PATCH 21/38] Add test for Java API: aggregate --- .../spark/sql/DataFrameFunctionsSuite.scala | 163 ++++++++++-------- 1 file changed, 94 insertions(+), 69 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4b9ae0a7aad45..4bc8dee71e25c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2730,31 +2730,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") + // aggregate(i, 0, (acc, x) -> acc + x) + val resA = Seq( + Row(25), + Row(31), + Row(0), + Row(null)) + + // aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10) + val resB = Seq( + Row(250), + Row(310), + Row(0), + Row(null)) + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), - Seq( - Row(25), - Row(31), - Row(0), - Row(null))) - checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), - Seq( - Row(250), - Row(310), - Row(0), - Row(null))) - checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), - Seq( - Row(25), - Row(31), - Row(0), - Row(null))) - checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), - Seq( - Row(250), - Row(310), - Row(0), - Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), resA) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), resB) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), resA) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), resB) + checkAnswer(df.select(aggregate(col("i"), lit(0), + new JFunc2 { + def call(acc: Column, x: Column): Column = acc + x + })), resA) + checkAnswer(df.select(aggregate(col("i"), lit(0), + new JFunc2 { + def call(acc: Column, x: Column): Column = acc + x + }, + new JFunc { + def call(x: Column): Column = x * 10 + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -2772,34 +2777,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") + // aggregate(i, 0, (acc, x) -> acc + x) + val resA = Seq( + Row(25), + Row(null), + Row(0), + Row(null)) + + // aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10) + val resB = Seq( + Row(250), + Row(0), + Row(0), + Row(null)) + def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), - Seq( - Row(25), - Row(null), - Row(0), - Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), resA) checkAnswer( df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"), - Seq( - Row(250), - Row(0), - Row(0), - Row(null))) - checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + 
x)), - Seq( - Row(25), - Row(null), - Row(0), - Row(null))) + resB) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), resA) checkAnswer( df.select( aggregate(col("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), - Seq( - Row(250), - Row(0), - Row(0), - Row(null))) + resB) + checkAnswer(df.select(aggregate(col("i"), lit(0), + new JFunc2 { + def call(acc: Column, x: Column): Column = acc + x + })), resA) + checkAnswer(df.select(aggregate(col("i"), lit(0), + new JFunc2 { + def call(acc: Column, x: Column): Column = acc + x + }, new JFunc { + def call(acc: Column): Column = coalesce(acc, lit(0)) * 10 + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -2817,35 +2828,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (null, "d") ).toDF("ss", "s") + val resA = Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null)) + + val resB = Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null)) + def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), - Seq( - Row("acab"), - Row(null), - Row("c"), - Row(null))) + checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), resA) checkAnswer( df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"), - Seq( - Row("acab"), - Row(""), - Row("c"), - Row(null))) - checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), - Seq( - Row("acab"), - Row(null), - Row("c"), - Row(null))) + resB) + checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), resA) checkAnswer( df.select( aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x), - acc => coalesce(acc, lit("")))), - Seq( - Row("acab"), - Row(""), - Row("c"), - Row(null))) + acc => coalesce(acc, lit("")))), resB) + checkAnswer(df.select(aggregate(col("ss"), col("s"), new JFunc2 { + def call(acc: Column, x: Column): Column = concat(acc, x) + })), resA) + checkAnswer(df.select(aggregate(col("ss"), col("s"), new JFunc2 { + def call(acc: Column, x: Column): Column = concat(acc, x) + }, new JFunc { + def call(acc: Column): Column = coalesce(acc, lit("")) + })), resB) } // Test with local relation, the Project will be evaluated without codegen @@ -2883,6 +2895,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3b = intercept[AnalysisException] { + df.select(aggregate(col("i"), lit(0), new JFunc2 { + def call(acc: Column, x: Column): Column = x + })) + } + assert(ex3b.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } @@ -2892,6 +2911,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(aggregate(col("s"), lit(0), (acc, x) => x)) } assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) + val ex4b = intercept[AnalysisException] { + df.select(aggregate(col("s"), lit(0), new JFunc2 { + def call(acc: Column, x: Column): Column = x + })) + } + assert(ex4b.getMessage.contains("data type mismatch: argument 3 requires int type")) val ex5 = intercept[AnalysisException] { df.selectExpr("aggregate(a, 0, (acc, x) -> x)") From 412ece54a8d9f2e80d3e7c6cd1ea050f4f7f3934 Mon Sep 17 00:00:00 2001 From: Nik Date: Mon, 19 Aug 2019 18:48:48 -0400 Subject: [PATCH 
22/38] Add test for Java API: map_zip_with --- .../spark/sql/DataFrameFunctionsSuite.scala | 65 ++++++++++++------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4bc8dee71e25c..b672437991768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2932,19 +2932,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (Map(5 -> 1L), null) ).toDF("m1", "m2") - checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), - Seq( - Row(Map(8 -> true, 3 -> false, 6 -> true)), - Row(Map(10 -> null, 8 -> false, 4 -> null)), - Row(Map(5 -> null)), - Row(null))) + val resA = Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null)) - checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), - Seq( - Row(Map(8 -> true, 3 -> false, 6 -> true)), - Row(Map(10 -> null, 8 -> false, 4 -> null)), - Row(Map(5 -> null)), - Row(null))) + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), resA) + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), resA) + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), new JFunc3 { + def call(k: Column, v1: Column, v2: Column): Column = k === v1 + v2 + })), resA) } test("map_zip_with function - map of non-primitive types") { @@ -2955,19 +2953,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (Map("a" -> "d"), null) ).toDF("m1", "m2") - checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), - Seq( - Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), - Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), - Row(Map("a" -> Row("d", null))), - Row(null))) + val resA = Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null)) + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), resA) checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), (k, v1, v2) => struct(v1, v2))), - Seq( - Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), - Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), - Row(Map("a" -> Row("d", null))), - Row(null))) + resA) + checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), new JFunc3 { + def call(k: Column, v1: Column, v2: Column): Column = struct(v1, v2) + })), resA) } test("map_zip_with function - invalid") { @@ -2992,6 +2989,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) + val ex2b = intercept[AnalysisException] { + df.select(map_zip_with(df("mis"), col("mmi"), new JFunc3 { + def call(x: Column, y: Column, z: Column): Column = concat(x, y, z) + })) + } + assert(ex2b.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) + val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> 
concat(x, y, z))") } @@ -3002,6 +3007,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("type mismatch: argument 1 requires map type")) + val ex3b = intercept[AnalysisException] { + df.select(map_zip_with(df("i"), col("mmi"), new JFunc3 { + def call(x: Column, y: Column, z: Column): Column = concat(x, y, z) + })) + } + assert(ex3b.getMessage.contains("type mismatch: argument 1 requires map type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } @@ -3012,6 +3024,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex4a.getMessage.contains("type mismatch: argument 2 requires map type")) + val ex4b = intercept[AnalysisException] { + df.select(map_zip_with(df("mis"), col("i"), new JFunc3 { + def call(x: Column, y: Column, z: Column): Column = concat(x, y, z) + })) + } + assert(ex4b.getMessage.contains("type mismatch: argument 2 requires map type")) + val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } From c49e7d3bfbd89bf9fe1ba2abef9a357587869871 Mon Sep 17 00:00:00 2001 From: Nik Date: Tue, 20 Aug 2019 23:46:23 -0400 Subject: [PATCH 23/38] Add java tests for transform_keys, transform_values --- .../spark/sql/DataFrameFunctionsSuite.scala | 286 +++++++++++++----- 1 file changed, 208 insertions(+), 78 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b672437991768..4439902a715ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3054,18 +3054,27 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Map[Array[Int], Boolean](Array(1, 2) -> false) ).toDF("y") + val res1 = Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))) - def testMapOfPrimitiveTypesCombination(): Unit = { - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), - Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + val res2 = Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))) + val res2a = Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))) + val res2b = Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))) - checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), - Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + val res3 = Seq(Row(Map(true -> true, true -> false))) + val res3a = Seq(Row(Map(50 -> true, 78 -> false))) - checkAnswer(dfExample2.selectExpr("transform_keys(j, " + - "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), - Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + val res4 = Seq(Row(Map(false -> false))) + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), res1) + checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), res1) + checkAnswer(dfExample1.select(transform_keys(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = k + v + })), res1) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), res2) checkAnswer(dfExample2.select( transform_keys( col("j"), @@ -3078,40 +3087,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) ) ), - 
Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + res2) + checkAnswer(dfExample2.select( + transform_keys( + col("j"), + new JFunc2 { + def call(k: Column, v: Column): Column = element_at( + map_from_arrays( + array(lit(1), lit(2), lit(3)), + array(lit("one"), lit("two"), lit("three")) + ), + k + ) + } + ) + ), + res2) checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), - Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) - + res2a) checkAnswer(dfExample2.select(transform_keys(col("j"), - (k, v) => (v * 2).cast("bigint") + k)), - Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) - - checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), - Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) - - checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), - Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) - - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), - Seq(Row(Map(true -> true, true -> false)))) - - checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), - Seq(Row(Map(true -> true, true -> false)))) - - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), - Seq(Row(Map(50 -> true, 78 -> false)))) - + (k, v) => (v * 2).cast("bigint") + k)), res2a) + checkAnswer(dfExample2.select(transform_keys(col("j"), + new JFunc2 { + def call(k: Column, v: Column): Column = + (v * 2).cast("bigint") + k + })), res2a) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), res2b) + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), res2b) + checkAnswer(dfExample2.select(transform_keys(col("j"), new JFunc2 { + def call(k: Column, v: Column): Column = k + v + })), res2b) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), res3) + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), res3) + checkAnswer(dfExample3.select(transform_keys(col("x"), new JFunc2 { + def call(k: Column, v: Column): Column = k % 2 === 0 || v + })), res3) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), res3a) + checkAnswer(dfExample3.select(transform_keys(col("x"), + (k, v) => when(v, k * 2).otherwise(k * 3))), res3a) checkAnswer(dfExample3.select(transform_keys(col("x"), - (k, v) => when(v, k * 2).otherwise(k * 3))), - Seq(Row(Map(50 -> true, 78 -> false)))) + new JFunc2 { + def call(k: Column, v: Column): Column = + when(v, k * 2).otherwise(k * 3) + })), res3a) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), - Seq(Row(Map(false -> false)))) - + res4) + checkAnswer(dfExample4.select(transform_keys(col("y"), + (k, v) => array_contains(k, lit(3)) && v)), res4) checkAnswer(dfExample4.select(transform_keys(col("y"), - (k, v) => array_contains(k, lit(3)) && v)), - Seq(Row(Map(false -> false)))) + new JFunc2 { + def call(k: Column, v: Column): Column = + array_contains(k, lit(3)) && v + })), res4) } // Test with local relation, the Project will be evaluated without codegen @@ -3154,6 +3186,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("Cannot use null as map key")) + val ex3b = intercept[Exception] { + dfExample1.select(transform_keys(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = v + })).show() + } + assert(ex3b.getMessage.contains("Cannot use null as map key")) + val ex4 = 
intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") } @@ -3182,82 +3221,124 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Map[Int, Array[Int]](1 -> Array(1, 2)) ).toDF("c") - def testMapOfPrimitiveTypesCombination(): Unit = { - checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), - Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + val res_1_1 = Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14))) - checkAnswer(dfExample2.selectExpr( - "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), - Seq(Row(Map(false -> "false", true -> "def")))) + val res_2_1 = Seq(Row(Map(false -> "false", true -> "def"))) + val res_2_2 = Seq(Row(Map(false -> true, true -> false))) - checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), - Seq(Row(Map(false -> true, true -> false)))) + val res_3_1 = Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9))) + val res_3_2 = Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3"))) + val res_3_3 = Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3"))) - checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), - Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + val res_4_1 = Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7"))) + val res_4_2 = Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3))) - checkAnswer(dfExample3.selectExpr( - "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), - Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) - - checkAnswer( - dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), - Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) - - checkAnswer( - dfExample4.selectExpr( - "transform_values(" + - "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + - "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), - Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) - - checkAnswer( - dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), - Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) - - checkAnswer( - dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), - Seq(Row(Map(1 -> 3)))) + val res_5_1 = Seq(Row(Map(1 -> 3))) + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + res_1_1) checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k + v)), - Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + res_1_1) + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = k + v + })), res_1_1) + + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), res_2_1) checkAnswer(dfExample2.select( transform_values(col("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), - Seq(Row(Map(false -> "false", true -> "def")))) + res_2_1) + checkAnswer(dfExample2.select( + transform_values(col("x"), new JFunc2 { + def call(k: Column, v: Column): Column = + when(k, v).otherwise(k.cast("string")) + })), res_2_1) + + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + res_2_2) checkAnswer(dfExample2.select(transform_values(col("x"), - (k, v) => (!k) && v === "abc")), - Seq(Row(Map(false -> true, true -> false)))) + (k, v) => (!k) && v === "abc")), res_2_2) + checkAnswer(dfExample2.select(transform_values(col("x"), + new JFunc2 { + def call(k: Column, v: Column): Column = (!k) && v === 
"abc" + })), res_2_2) + + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), res_3_1) checkAnswer(dfExample3.select(transform_values(col("y"), (k, v) => v * v)), - Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + res_3_1) + checkAnswer(dfExample3.select(transform_values(col("y"), new JFunc2{ + def call(k: Column, v: Column): Column = v * v + })), res_3_1) + + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), res_3_2) checkAnswer(dfExample3.select( transform_values(col("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), - Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + res_3_2) + checkAnswer(dfExample3.select( + transform_values(col("y"), new JFunc2 { + def call(k: Column, v: Column): Column = + concat(k, lit(":"), v.cast("string")) + })), res_3_2) + + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + res_3_3) checkAnswer( dfExample3.select(transform_values(col("y"), (k, v) => concat(k, v.cast("string")))), - Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + res_3_3) + checkAnswer( + dfExample3.select(transform_values(col("y"), new JFunc2 { + def call(k: Column, v: Column): Column = + concat(k, v.cast("string")) + })), res_3_3) + + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + res_4_1) val testMap = map_from_arrays( array(lit(1), lit(2), lit(3)), array(lit("one"), lit("two"), lit("three")) ) - checkAnswer( dfExample4.select(transform_values(col("z"), (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))), - Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + res_4_1) + checkAnswer( + dfExample4.select(transform_values(col("z"), new JFunc2 { + def call(k: Column, v: Column): Column = + concat(element_at(testMap, k), lit("_"), v.cast("string")) + })), res_4_1) + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), res_4_2) checkAnswer( dfExample4.select(transform_values(col("z"), (k, v) => k - v)), - Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + res_4_2) + checkAnswer( + dfExample4.select(transform_values(col("z"), new JFunc2 { + def call(k: Column, v: Column): Column = k - v + })), res_4_2) + checkAnswer( + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), res_5_1) checkAnswer( dfExample5.select(transform_values(col("c"), (k, v) => k + size(v))), - Seq(Row(Map(1 -> 3)))) + res_5_1) + checkAnswer( + dfExample5.select(transform_values(col("c"), new JFunc2 { + def call(k: Column, v: Column): Column = k + size(v) + })), res_5_1) } // Test with local relation, the Project will be evaluated without codegen @@ -3323,6 +3404,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v.cast("bigint"))), Seq(Row(Map.empty[BigInt, BigInt]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), + new JFunc2 { + def call(k: Column, v: Column): Column = lit(null).cast("int")})), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = k})), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = v})), + Seq(Row(Map.empty[Integer, 
Integer]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = lit(0)})), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = lit("value")})), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = lit(true)})), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { + def call(k: Column, v: Column): Column = v.cast("bigint")})), + Seq(Row(Map.empty[BigInt, BigInt]))) } testEmpty() @@ -3356,6 +3466,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { transform_values(col("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) ), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + + checkAnswer(dfExample1.select(transform_values(col("a"), + new JFunc2 { + def call(k: Column, v: Column): Column = lit(null).cast("int") + })), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.select( + transform_values(col("b"), new JFunc2 { + def call(k: Column, v: Column): Column = + when(v.isNull, k + 1).otherwise(k + 2) + })), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) } testNullValue() @@ -3400,6 +3523,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains( "data type mismatch: argument 1 requires map type")) + + val ex3b = intercept[AnalysisException] { + dfExample3.select(transform_values(col("x"), new JFunc2 { + def call(k: Column, v: Column): Column = k + 1})) + } + assert(ex3b.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() From 182a08ba872c99bc062d8bcd87f0b930f9bae96c Mon Sep 17 00:00:00 2001 From: Nik Date: Tue, 20 Aug 2019 23:49:40 -0400 Subject: [PATCH 24/38] Add tests for java zip_with function --- .../spark/sql/DataFrameFunctionsSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4439902a715ef..3d41da8066a91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3557,6 +3557,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null)) checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)), expectedValue1) + checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), new JFunc2 { + def call(x: Column, y: Column): Column = x + y})), expectedValue1) val expectedValue2 = Seq( Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) @@ -3565,6 +3567,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df2.select(zip_with(df2("val1"), df2("val2"), (x, y) => struct(y, x))), expectedValue2 ) + checkAnswer( + df2.select(zip_with(df2("val1"), df2("val2"), new JFunc2 { + def call(x: Column, y: Column): Column = struct(y, x)})), + expectedValue2 + ) } test("arrays zip_with function - for non-primitive types") { @@ -3587,6 +3594,11 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession { df.select(zip_with(col("val1"), col("val2"), (x, y) => struct(y, x))), expectedValue1 ) + checkAnswer( + df.select(zip_with(col("val1"), col("val2"), new JFunc2 { + def call(x: Column, y: Column): Column = struct(y, x)})), + expectedValue1 + ) } test("arrays zip_with function - invalid") { @@ -3612,6 +3624,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(zip_with(df("i"), df("a2"), (acc, x) => x)) } assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3b = intercept[AnalysisException] { + df.select(zip_with(df("i"), df("a2"), new JFunc2 { + def call(acc: Column, x: Column): Column = x})) + } + assert(ex3b.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex4 = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a, (acc, x) -> x)") } From ef6b6bbab0aaed7f4f19918a9a04530c79996b16 Mon Sep 17 00:00:00 2001 From: Nik Date: Wed, 21 Aug 2019 18:56:46 -0400 Subject: [PATCH 25/38] Remove JavaFunction overloads and add Java transform test --- .../org/apache/spark/sql/functions.scala | 159 ------------- .../sql/JavaHigherOrderFunctionsSuite.java | 215 ++++++++++++++++++ .../org/apache/spark/sql/JavaTestUtils.java | 47 ++++ 3 files changed, 262 insertions(+), 159 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9c2a57364b738..abbf94d61a96d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,7 +24,6 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.Stable -import org.apache.spark.api.java.function.{Function => JavaFunction, Function2 => JavaFunction2, Function3 => JavaFunction3} import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -3365,27 +3364,6 @@ object functions { LambdaFunction(function, Seq(x, y, z)) } - private def createLambda(f: JavaFunction[Column, Column]) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val function = f.call(Column(x)).expr - LambdaFunction(function, Seq(x)) - } - - private def createLambda(f: JavaFunction2[Column, Column, Column]) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) - val function = f.call(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) - } - - private def createLambda(f: JavaFunction3[Column, Column, Column, Column]) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) - val z = UnresolvedNamedLambdaVariable(Seq("z")) - val function = f.call(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) - } - /** * (Scala-specific) Returns an array of elements after applying a tranformation to each element * in the input array. @@ -3521,143 +3499,6 @@ object functions { MapZipWith(left.expr, right.expr, createLambda(f)) } - /** - * (Java-specific) Returns an array of elements after applying a tranformation to each element - * in the input array. 
- * - * @group collection_funcs - * @since 3.0.0 - */ - def transform(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { - ArrayTransform(column.expr, createLambda(f)) - } - - /** - * (Java-specific) Returns an array of elements after applying a tranformation to each element - * in the input array. - * - * @group collection_funcs - * @since 3.0.0 - */ - def transform(column: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { - ArrayTransform(column.expr, createLambda(f)) - } - - /** - * (Java-specific) Returns whether a predicate holds for one or more elements in the array. - * - * @group collection_funcs - * @since 3.0.0 - */ - def exists(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { - ArrayExists(column.expr, createLambda(f)) - } - - /** - * (Java-specific) Returns whether a predicate holds for every element in the array. - * - * @group collection_funcs - * @since 3.0.0 - */ - def forall(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { - ArrayForAll(column.expr, createLambda(f)) - } - - /** - * (Java-specific) Returns an array of elements for which a predicate holds in a given array. - * - * @group collection_funcs - * @since 3.0.0 - */ - def filter(column: Column, f: JavaFunction[Column, Column]): Column = withExpr { - ArrayFilter(column.expr, createLambda(f)) - } - - /** - * (Java-specific) Applies a binary operator to an initial state and all elements in the array, - * and reduces this to a single state. The final state is converted into the final result - * by applying a finish function. - * - * @group collection_funcs - * @since 3.0.0 - */ - def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column], - finish: JavaFunction[Column, Column]): Column = withExpr { - ArrayAggregate( - expr.expr, - zero.expr, - createLambda(merge), - createLambda(finish) - ) - } - - /** - * (Java-specific) Applies a binary operator to an initial state and all elements in the array, - * and reduces this to a single state. - * - * @group collection_funcs - * @since 3.0.0 - */ - def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column]): Column = - aggregate( - expr, zero, merge, new JavaFunction[Column, Column] { def call(c: Column): Column = c }) - - /** - * (Java-specific) Merge two given arrays, element-wise, into a signle array using a function. - * If one array is shorter, nulls are appended at the end to match the length of the longer - * array, before applying the function. - * - * @group collection_funcs - * @since 3.0.0 - */ - def zip_with(left: Column, right: Column, f: JavaFunction2[Column, Column, Column]): Column = - withExpr { - ZipWith(left.expr, right.expr, createLambda(f)) - } - - /** - * (Java-specific) Applies a function to every key-value pair in a map and returns - * a map with the results of those applications as the new keys for the pairs. - * - * @group collection_funcs - * @since 3.0.0 - */ - def transform_keys(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { - TransformKeys(expr.expr, createLambda(f)) - } - - /** - * (Java-specific) Applies a function to every key-value pair in a map and returns - * a map with the results of those applications as the new values for the pairs. 
- * - * @group collection_funcs - * @since 3.0.0 - */ - def transform_values(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { - TransformValues(expr.expr, createLambda(f)) - } - - /** - * (Java-specific) Returns a map whose key-value pairs satisfy a predicate. - * - * @group collection_funcs - * @since 3.0.0 - */ - def map_filter(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr { - MapFilter(expr.expr, createLambda(f)) - } - - /** - * (Java-specific) Merge two given maps, key-wise into a single map using a function. - * - * @group collection_funcs - * @since 3.0.0 - */ - def map_zip_with(left: Column, right: Column, - f: JavaFunction3[Column, Column, Column, Column]): Column = withExpr { - MapZipWith(left.expr, right.expr, createLambda(f)) - } - /** * Creates a new row for each element in the given array or map column. * Uses the default column name `col` for elements in the array and diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java new file mode 100644 index 0000000000000..b1cc9682d0813 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package test.org.apache.spark.sql;
+
+import java.util.List;
+
+import scala.collection.Seq;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.test.TestSparkSession;
+import static test.org.apache.spark.sql.JavaTestUtils.*;
+import test.org.apache.spark.sql.JavaTestUtils;
+
+public class JavaHigherOrderFunctionsSuite {
+  private transient TestSparkSession spark;
+
+  @Before
+  public void setUp() {
+    spark = new TestSparkSession();
+  }
+
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
+
+  @Test
+  public void testTransformArrayPrimitiveNotContainingNull() {
+    List<Row> data = toRows(
+      makeArray(1, 9, 8, 7),
+      makeArray(5, 8, 9, 7, 2),
+      JavaTestUtils.makeArray(),
+      null
+    );
+    StructType schema = new StructType()
+      .add("i", new ArrayType(IntegerType, true), true);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
+
+    Runnable f = () -> {
+      checkAnswer(
+        df.select(transform(col("i"), x -> x.plus(1))),
+        toRows(
+          makeArray(2, 10, 9, 8),
+          makeArray(6, 9, 10, 8, 3),
+          JavaTestUtils.makeArray(),
+          null
+        ));
+      checkAnswer(
+        df.select(transform(col("i"), (x, i) -> x.plus(i))),
+        toRows(
+          makeArray(1, 10, 10, 10),
+          makeArray(5, 9, 11, 10, 6),
+          JavaTestUtils.makeArray(),
+          null
+        ));
+    };
+
+    // Test with local relation, the Project will be evaluated without codegen
+    f.run();
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache();
+    f.run();
+  }
+
+  @Test
+  public void testTransformArrayPrimitiveContainingNull() {
+    List<Row> data = toRows(
+      makeArray(1, 9, 8, null, 7),
+      makeArray(5, null, 8, 9, 7, 2),
+      JavaTestUtils.makeArray(),
+      null
+    );
+    StructType schema = new StructType()
+      .add("i", new ArrayType(IntegerType, true), true);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
+
+    Runnable f = () -> {
+      checkAnswer(
+        df.select(transform(col("i"), x -> x.plus(1))),
+        toRows(
+          makeArray(2, 10, 9, null, 8),
+          makeArray(6, null, 9, 10, 8, 3),
+          JavaTestUtils.makeArray(),
+          null
+        ));
+      checkAnswer(
+        df.select(transform(col("i"), (x, i) -> x.plus(i))),
+        toRows(
+          makeArray(1, 10, 10, null, 11),
+          makeArray(5, null, 10, 12, 11, 7),
+          JavaTestUtils.makeArray(),
+          null
+        ));
+    };
+
+    // Test with local relation, the Project will be evaluated without codegen
+    f.run();
+    df.cache();
+    // Test with cached relation, the Project will be evaluated with codegen
+    f.run();
+  }
+
+  @Test
+  public void testTransformArrayNonPrimitive() {
+    List<Row> data = toRows(
+      makeArray("c", "a", "b"),
+      makeArray("b", null, "c", null),
+      JavaTestUtils.makeArray(),
+      null
+    );
+    StructType schema = new StructType()
+      .add("s", new ArrayType(StringType, true), true);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
+
+    Runnable f = () -> {
+      checkAnswer(df.select(transform(col("s"), x -> concat(x, x))),
+        toRows(
+          makeArray("cc", "aa", "bb"),
+          makeArray("bb", null, "cc", null),
+          JavaTestUtils.makeArray(),
+          null
+        ));
+      checkAnswer(df.select(transform(col("s"), (x, i) -> concat(x, i))),
+        toRows(
+          makeArray("c0", "a1", "b2"),
+          makeArray("b0", null, "c2", null),
+          JavaTestUtils.makeArray(),
+          null
+        ));
+    };
+
+    // Test with local relation, the Project will be evaluated without codegen
+    f.run();
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache();
+    f.run();
+  }
+
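+  // The cases below exercise lambda parameters that shadow the input column
+  // name ("arg") and lambdas that capture the full column instead of the
+  // element passed to them.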
+  @Test
+  public void testTransformSpecialCases() {
+    List<Row> data = toRows(
+      makeArray("c", "a", "b"),
+      makeArray("b", null, "c", null),
+      JavaTestUtils.makeArray(),
+      null
+    );
+    StructType schema = new StructType()
+      .add("arg", new ArrayType(StringType, true), true);
+    Dataset<Row> df = spark.createDataFrame(data, schema);
+
+    Runnable f = () -> {
+      checkAnswer(df.select(transform(col("arg"), arg -> arg)),
+        toRows(
+          makeArray("c", "a", "b"),
+          makeArray("b", null, "c", null),
+          JavaTestUtils.makeArray(),
+          null));
+      checkAnswer(df.select(transform(col("arg"), x -> col("arg"))),
+        toRows(
+          makeArray(
+            makeArray("c", "a", "b"),
+            makeArray("c", "a", "b"),
+            makeArray("c", "a", "b")
+          ),
+          makeArray(
+            makeArray("b", null, "c", null),
+            makeArray("b", null, "c", null),
+            makeArray("b", null, "c", null),
+            makeArray("b", null, "c", null)
+          ),
+          JavaTestUtils.makeArray(),
+          null));
+      checkAnswer(df.select(transform(col("arg"), x -> concat(col("arg"), array(x)))),
+        toRows(
+          makeArray(
+            makeArray("c", "a", "b", "c"),
+            makeArray("c", "a", "b", "a"),
+            makeArray("c", "a", "b", "b")
+          ),
+          makeArray(
+            makeArray("b", null, "c", null, "b"),
+            makeArray("b", null, "c", null, null),
+            makeArray("b", null, "c", null, "c"),
+            makeArray("b", null, "c", null, null)
+          ),
+          JavaTestUtils.makeArray(),
+          null));
+    };
+
+    // Test with local relation, the Project will be evaluated without codegen
+    f.run();
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache();
+    f.run();
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java
new file mode 100644
index 0000000000000..7fc6460e7352c
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.util.Arrays;
+import java.util.List;
+import static java.util.stream.Collectors.toList;
+
+import scala.collection.mutable.WrappedArray;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+
+public class JavaTestUtils {
+  public static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
+    assertEquals(expected, actual.collectAsList());
+  }
+
+  public static List<Row> toRows(Object... objs) {
+    return Arrays.asList(objs)
+      .stream()
+      .map(RowFactory::create)
+      .collect(toList());
+  }
+
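+  // Returns a Scala WrappedArray so that expected values compare equal to the
+  // Seq-valued array columns that Spark returns in collected Rows.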
+  public static <T> WrappedArray<T> makeArray(T... ts) {
+    return WrappedArray.make(ts);
+  }
+}

From 527c0cbd9878b4f56bd38dacac3783b80cd6b659 Mon Sep 17 00:00:00 2001
From: Nik
Date: Wed, 21 Aug 2019 19:02:45 -0400
Subject: [PATCH 26/38] Remove (Scala-specific) from higher order functions

---
 .../org/apache/spark/sql/functions.scala     | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d1b399b95a8f5..9413f51aeaa9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3407,7 +3407,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Returns an array of elements after applying a tranformation to each element
+   * Returns an array of elements after applying a transformation to each element
    * in the input array.
    *
    * @group collection_funcs
@@ -3418,7 +3418,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Returns an array of elements after applying a tranformation to each element
+   * Returns an array of elements after applying a transformation to each element
    * in the input array.
    *
    * @group collection_funcs
@@ -3429,7 +3429,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Returns whether a predicate holds for one or more elements in the array.
+   * Returns whether a predicate holds for one or more elements in the array.
    *
    * @group collection_funcs
    * @since 3.0.0
@@ -3439,7 +3439,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Returns whether a predicate holds for every element in the array.
+   * Returns whether a predicate holds for every element in the array.
    *
    * @group collection_funcs
    * @since 3.0.0
@@ -3449,7 +3449,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Returns an array of elements for which a predicate holds in a given array.
+   * Returns an array of elements for which a predicate holds in a given array.
    *
    * @group collection_funcs
    * @since 3.0.0
@@ -3459,7 +3459,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
+   * Applies a binary operator to an initial state and all elements in the array,
    * and reduces this to a single state. The final state is converted into the final result
    * by applying a finish function.
    *
@@ -3477,7 +3477,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
+   * Applies a binary operator to an initial state and all elements in the array,
    * and reduces this to a single state.
    *
    * @group collection_funcs
@@ -3487,7 +3487,7 @@ object functions {
     aggregate(expr, zero, merge, c => c)
 
   /**
-   * (Scala-specific) Merge two given arrays, element-wise, into a signle array using a function.
+   * Merge two given arrays, element-wise, into a single array using a function.
    * If one array is shorter, nulls are appended at the end to match the length of the longer
    * array, before applying the function.
    *
@@ -3499,7 +3499,7 @@ object functions {
   }
 
   /**
-   * (Scala-specific) Applies a function to every key-value pair in a map and returns
+   * Applies a function to every key-value pair in a map and returns
    * a map with the results of those applications as the new keys for the pairs.
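+   * For example, assuming a DataFrame `df` with a map column `m` of integer
+   * keys and values, the following replaces each key with `k + v`:
+   * {{{
+   *   df.select(transform_keys(col("m"), (k, v) => k + v))
+   * }}}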
* * @group collection_funcs @@ -3510,7 +3510,7 @@ object functions { } /** - * (Scala-specific) Applies a function to every key-value pair in a map and returns + * Applies a function to every key-value pair in a map and returns * a map with the results of those applications as the new values for the pairs. * * @group collection_funcs @@ -3521,7 +3521,7 @@ object functions { } /** - * (Scala-specific) Returns a map whose key-value pairs satisfy a predicate. + * Returns a map whose key-value pairs satisfy a predicate. * * @group collection_funcs * @since 3.0.0 @@ -3531,7 +3531,7 @@ object functions { } /** - * (Scala-specific) Merge two given maps, key-wise into a single map using a function. + * Merge two given maps, key-wise into a single map using a function. * * @group collection_funcs * @since 3.0.0 From 013187fd60aeea4c9581fea126027e050b88fff2 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Tue, 17 Sep 2019 00:55:51 -0400 Subject: [PATCH 27/38] Remove java tests from DataFrameFunctionsSuite --- .../spark/sql/DataFrameFunctionsSuite.scala | 1163 ++++++----------- 1 file changed, 431 insertions(+), 732 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3d41da8066a91..dbe72b64b4d17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -23,7 +23,6 @@ import java.util.TimeZone import scala.util.Random -import org.apache.spark.api.java.function.{Function => JavaFunction, Function2 => JavaFunction2, Function3 => JavaFunction3} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -38,9 +37,6 @@ import org.apache.spark.sql.types._ * Test suite for functions in [[org.apache.spark.sql.functions]]. 
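+ * Java-specific variants of the higher-order functions are exercised separately
+ * in `test.org.apache.spark.sql.JavaHigherOrderFunctionsSuite`.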
*/ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { - type JFunc = JavaFunction[Column, Column] - type JFunc2 = JavaFunction2[Column, Column, Column] - type JFunc3 = JavaFunction3[Column, Column, Column, Column] import testImplicits._ test("array with column name") { @@ -1921,33 +1917,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // transform(i, x -> x + 1) - val resA = Seq( - Row(Seq(2, 10, 9, 8)), - Row(Seq(6, 9, 10, 8, 3)), - Row(Seq.empty), - Row(null)) - - // transform(i, (x, i) -> x + i) - val resB = Seq( - Row(Seq(1, 10, 10, 10)), - Row(Seq(5, 9, 11, 10, 6)), - Row(Seq.empty), - Row(null)) - def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), resA) - checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), resB) - - checkAnswer(df.select(transform(col("i"), x => x + 1)), resA) - checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), resB) - - checkAnswer(df.select(transform(col("i"), new JFunc { - def call(x: Column) = x + 1 - })), resA) - checkAnswer(df.select(transform(col("i"), new JFunc2 { - def call(x: Column, i: Column) = x + i - })), resB) + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("i"), x => x + 1)), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -1965,33 +1959,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // transform(i, x -> x + 1) - val resA = Seq( - Row(Seq(2, 10, 9, null, 8)), - Row(Seq(6, null, 9, 10, 8, 3)), - Row(Seq.empty), - Row(null)) - - // transform(i, (x, i) -> x + i) - val resB = Seq( - Row(Seq(1, 10, 10, null, 11)), - Row(Seq(5, null, 10, 12, 11, 7)), - Row(Seq.empty), - Row(null)) - def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), resA) - checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), resB) - - checkAnswer(df.select(transform(col("i"), x => x + 1)), resA) - checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), resB) - - checkAnswer(df.select(transform(col("i"), new JFunc { - def call(x: Column) = x + 1 - })), resA) - checkAnswer(df.select(transform(col("i"), new JFunc2 { - def call(x: Column, i: Column) = x + i - })), resB) + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("i"), x => x + 1)), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("i"), (x, i) => x + i)), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the 
Project will be evaluated without codegen @@ -2009,33 +2001,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("s") - // transform(s, x -> concat(x, x)) - val resA = Seq( - Row(Seq("cc", "aa", "bb")), - Row(Seq("bb", null, "cc", null)), - Row(Seq.empty), - Row(null)) - - // transform(s, (x, i) -> concat(x, i)) - val resB = Seq( - Row(Seq("c0", "a1", "b2")), - Row(Seq("b0", null, "c2", null)), - Row(Seq.empty), - Row(null)) - def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), resA) - checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), resB) - - checkAnswer(df.select(transform(col("s"), x => concat(x, x))), resA) - checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))), resB) - - checkAnswer(df.select(transform(col("s"), new JFunc { - def call(x: Column) = concat(x, x) - })), resA) - checkAnswer(df.select(transform(col("s"), new JFunc2 { - def call(x: Column, i: Column) = concat(x, i) - })), resB) + checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("s"), x => concat(x, x))), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2053,54 +2043,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("arg") - // transform(arg, arg -> arg) - val resA = + def testSpecialCases(): Unit = { + checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), Seq( Row(Seq("c", "a", "b")), Row(Seq("b", null, "c", null)), Row(Seq.empty), - Row(null)) - - // transform(arg, arg) - val resB = Seq( - Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), - Row(Seq( - Seq("b", null, "c", null), - Seq("b", null, "c", null), - Seq("b", null, "c", null), - Seq("b", null, "c", null))), - Row(Seq.empty), - Row(null)) - - // transform(arg, x -> concat(arg, array(x))) - val resC = Seq( - Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), - Row(Seq( - Seq("b", null, "c", null, "b"), - Seq("b", null, "c", null, null), - Seq("b", null, "c", null, "c"), - Seq("b", null, "c", null, null))), - Row(Seq.empty), - Row(null)) - - def testSpecialCases(): Unit = { - checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), resA) - checkAnswer(df.selectExpr("transform(arg, arg)"), resB) - checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), resC) - - checkAnswer(df.select(transform(col("arg"), arg => arg)), resA) - checkAnswer(df.select(transform(col("arg"), _ => col("arg"))), resB) - checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))), resC) - - checkAnswer(df.select(transform(col("arg"), new JFunc { - def call(arg: Column) = arg - })), resA) - checkAnswer(df.select(transform(col("arg"), new JFunc { - def call(arg: Column) = col("arg") - })), resB) - checkAnswer(df.select(transform(col("arg"), new JFunc { - def call(x: Column) = concat(col("arg"), array(x)) - })), 
resC) + Row(null))) + checkAnswer(df.selectExpr("transform(arg, arg)"), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("arg"), arg => arg)), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("arg"), _ => col("arg"))), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2128,34 +2123,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2a = intercept[AnalysisException] { - df.select(transform(col("i"), x => x)) - } - assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) - - val ex2b = intercept[AnalysisException] { - df.select(transform(col("i"), new JFunc { - def call(x: Column) = x - })) - } - assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3 = intercept[AnalysisException] { df.selectExpr("transform(a, x -> x)") } assert(ex3.getMessage.contains("cannot resolve '`a`'")) - - val ex3a = intercept[AnalysisException] { - df.select(transform(col("a"), x => x)) - } - assert(ex3a.getMessage.contains("cannot resolve '`a`'")) - - val ex3b = intercept[AnalysisException] { - df.select(transform(col("a"), new JFunc { - def call(x: Column) = x - })) - } - assert(ex3b.getMessage.contains("cannot resolve '`a`'")) } test("map_filter") { @@ -2164,54 +2135,37 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Map(1 -> -1, 2 -> -2, 3 -> -3), Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m") - // map_filter(m, (k, v) -> k * 10 = v), map_filter(m, (k, v) -> k = -v - val resA = Seq( - Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), - Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), - Row(Map(1 -> 10), Map(3 -> -3))) - checkAnswer(dfInts.selectExpr( "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"), - resA) + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) checkAnswer(dfInts.select( map_filter(col("m"), (k, v) => k * 10 === v), map_filter(col("m"), (k, v) => k === (v * -1))), - resA) - - checkAnswer(dfInts.select( - map_filter(col("m"), new JFunc2 { - def call(k: Column, v: Column): Column = k * 10 === v - }), - 
map_filter(col("m"), new JFunc2 { - def call(k: Column, v: Column): Column = k === (v * -1) - })), resA) + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) val dfComplex = Seq( Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") - // map_filter(m, (k, v) -> k = v[0]), map_filter(m, (k, v) -> k = size(v)) - val resB = Seq( - Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), - Row(Map(), Map(2 -> Seq(-2, -2)))) - checkAnswer(dfComplex.selectExpr( "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"), - resB) + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) checkAnswer(dfComplex.select( map_filter(col("m"), (k, v) => k === element_at(v, 1)), map_filter(col("m"), (k, v) => k === size(v))), - resB) - - checkAnswer(dfComplex.select( - map_filter(col("m"), new JFunc2 { - def call(k: Column, v: Column): Column = k === element_at(v, 1) - }), - map_filter(col("m"), new JFunc2 { - def call(k: Column, v: Column): Column = k === size(v) - })), resB) + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) // Invalid use cases val df = Seq( @@ -2240,29 +2194,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires map type")) - val ex3b = intercept[AnalysisException] { - df.select(map_filter(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = k > v - })) - } - assert(ex3b.getMessage.contains("data type mismatch: argument 1 requires map type")) - val ex4 = intercept[AnalysisException] { df.selectExpr("map_filter(a, (k, v) -> k > v)") } assert(ex4.getMessage.contains("cannot resolve '`a`'")) - - val ex4a = intercept[AnalysisException] { - df.select(map_filter(col("a"), (k, v) => k > v)) - } - assert(ex4a.getMessage.contains("cannot resolve '`a`'")) - - val ex4b = intercept[AnalysisException] { - df.select(map_filter(col("a"), new JFunc2 { - def call(k: Column, v: Column): Column = k > v - })) - } - assert(ex4b.getMessage.contains("cannot resolve '`a`'")) } test("filter function - array for primitive type not containing null") { @@ -2273,19 +2208,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // filter(i, x -> x % 2 == 0) - val res = Seq( - Row(Seq(8)), - Row(Seq(8, 2)), - Row(Seq.empty), - Row(null)) - def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), res) - checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), res) - checkAnswer(df.select(filter(col("i"), new JFunc { - def call(x: Column): Column = x % 2 === 0 - })), res) + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2303,19 +2238,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // filter(i, x -> x % 2 == 0) - val res = Seq( - Row(Seq(8)), - Row(Seq(8, 2)), - Row(Seq.empty), - Row(null)) - def 
testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), res) - checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), res) - checkAnswer(df.select(filter(col("i"), new JFunc { - def call(x: Column): Column = x % 2 === 0 - })), res) + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(filter(col("i"), _ % 2 === 0)), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2333,19 +2268,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("s") - // filter(s, x -> x is not null) - val res = Seq( - Row(Seq("c", "a", "b")), - Row(Seq("b", "c")), - Row(Seq.empty), - Row(null)) - def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), res) - checkAnswer(df.select(filter(col("s"), x => x.isNotNull)), res) - checkAnswer(df.select(filter(col("s"), new JFunc { - def call(x: Column) = x.isNotNull - })), res) + checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) + checkAnswer(df.select(filter(col("s"), x => x.isNotNull)), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2378,13 +2313,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2b = intercept[AnalysisException] { - df.select(filter(col("i"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3 = intercept[AnalysisException] { df.selectExpr("filter(s, x -> x)") } @@ -2395,13 +2323,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) - val ex3b = intercept[AnalysisException] { - df.select(filter(col("s"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex3b.getMessage.contains("data type mismatch: argument 2 requires boolean type")) - val ex4 = intercept[AnalysisException] { df.selectExpr("filter(a, x -> x)") } @@ -2416,19 +2337,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // exists(i, x -> x % 2 == 0) - val res = Seq( - Row(true), - Row(false), - Row(false), - Row(null)) - def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), res) - checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), res) - checkAnswer(df.select(exists(col("i"), new JFunc { - def call(x: Column): Column = x % 2 === 0 - })), res) + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2447,20 +2368,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // exists(i, x -> x % 2 == 0) - val res = Seq( - Row(true), - Row(false), - Row(null), - Row(false), - Row(null)) - def 
testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), res) - checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), res) - checkAnswer(df.select(exists(col("i"), new JFunc { - def call(x: Column): Column = x % 2 === 0 - })), res) + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(null), + Row(false), + Row(null))) + checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(null), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2478,19 +2400,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("s") - // exists(s, x -> x is null) - val res = Seq( - Row(false), - Row(true), - Row(false), - Row(null)) - def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("exists(s, x -> x is null)"), res) - checkAnswer(df.select(exists(col("s"), x => x.isNull)), res) - checkAnswer(df.select(exists(col("s"), new JFunc { - def call(x: Column): Column = x.isNull - })), res) + checkAnswer(df.selectExpr("exists(s, x -> x is null)"), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) + checkAnswer(df.select(exists(col("s"), x => x.isNull)), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2523,13 +2445,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2b = intercept[AnalysisException] { - df.select(exists(col("i"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } @@ -2540,13 +2455,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) - val ex3b = intercept[AnalysisException] { - df.select(exists(df("s"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex3b.getMessage.contains("data type mismatch: argument 2 requires boolean type")) - val ex4 = intercept[AnalysisException] { df.selectExpr("exists(a, x -> x)") } @@ -2561,18 +2469,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - val resA = Seq( - Row(false), - Row(true), - Row(true), - Row(null)) - def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), resA) - checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), resA) - checkAnswer(df.select(forall(col("i"), new JFunc { - def call(x: Column): Column = x % 2 === 0 - })), resA) + checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2591,33 +2500,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // forall(i, x -> x % 2 == 0 or x is null) - val resA = Seq( - Row(false), - Row(true), - Row(true), - Row(true), - Row(null)) - - // forall(i, x -> x % 2 == 0) - val resB = Seq( - Row(false), - Row(null), - Row(true), - Row(true), 
- Row(null)) - def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0 or x is null)"), resA) - checkAnswer(df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull)), resA) - checkAnswer(df.select(forall(col("i"), new JFunc { - def call(x: Column): Column = (x % 2 === 0 ) || x.isNull - })), resA) - checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), resB) - checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), resB) - checkAnswer(df.select(forall(col("i"), new JFunc { - def call(x: Column): Column = (x % 2 === 0 ) - })), resB) + checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0 or x is null)"), + Seq( + Row(false), + Row(true), + Row(true), + Row(true), + Row(null))) + checkAnswer(df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull)), + Seq( + Row(false), + Row(true), + Row(true), + Row(true), + Row(null))) + checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), + Seq( + Row(false), + Row(null), + Row(true), + Row(true), + Row(null))) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), + Seq( + Row(false), + Row(null), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2635,19 +2546,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("s") - // forall(s, x -> x is null) - val resA = Seq( - Row(false), - Row(true), - Row(true), - Row(null)) - def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("forall(s, x -> x is null)"), resA) - checkAnswer(df.select(forall(col("s"), _.isNull)), resA) - checkAnswer(df.select(forall(col("s"), new JFunc { - def call(x: Column): Column = x.isNull - })), resA) + checkAnswer(df.selectExpr("forall(s, x -> x is null)"), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) + checkAnswer(df.select(forall(col("s"), _.isNull)), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2680,13 +2591,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex2b = intercept[AnalysisException] { - df.select(forall(col("i"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex2b.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3 = intercept[AnalysisException] { df.selectExpr("forall(s, x -> x)") } @@ -2697,13 +2601,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) - val ex3b = intercept[AnalysisException] { - df.select(forall(col("s"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex3b.getMessage.contains("data type mismatch: argument 2 requires boolean type")) - val ex4 = intercept[AnalysisException] { df.selectExpr("forall(a, x -> x)") } @@ -2713,13 +2610,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(forall(col("a"), x => x)) } assert(ex4a.getMessage.contains("cannot resolve '`a`'")) - - val ex4b = intercept[AnalysisException] { - df.select(forall(col("a"), new JFunc { - def call(x: Column): Column = x - })) - } - assert(ex4b.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { @@ -2730,36 +2620,31 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession { null ).toDF("i") - // aggregate(i, 0, (acc, x) -> acc + x) - val resA = Seq( - Row(25), - Row(31), - Row(0), - Row(null)) - - // aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10) - val resB = Seq( - Row(250), - Row(310), - Row(0), - Row(null)) - def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { - checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), resA) - checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), resB) - checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), resA) - checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), resB) - checkAnswer(df.select(aggregate(col("i"), lit(0), - new JFunc2 { - def call(acc: Column, x: Column): Column = acc + x - })), resA) - checkAnswer(df.select(aggregate(col("i"), lit(0), - new JFunc2 { - def call(acc: Column, x: Column): Column = acc + x - }, - new JFunc { - def call(x: Column): Column = x * 10 - })), resB) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2777,40 +2662,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { null ).toDF("i") - // aggregate(i, 0, (acc, x) -> acc + x) - val resA = Seq( - Row(25), - Row(null), - Row(0), - Row(null)) - - // aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10) - val resB = Seq( - Row(250), - Row(0), - Row(0), - Row(null)) - def testArrayOfPrimitiveTypeContainsNull(): Unit = { - checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), resA) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) checkAnswer( df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"), - resB) - checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), resA) + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) checkAnswer( df.select( aggregate(col("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), - resB) - checkAnswer(df.select(aggregate(col("i"), lit(0), - new JFunc2 { - def call(acc: Column, x: Column): Column = acc + x - })), resA) - checkAnswer(df.select(aggregate(col("i"), lit(0), - new JFunc2 { - def call(acc: Column, x: Column): Column = acc + x - }, new JFunc { - def call(acc: Column): Column = coalesce(acc, lit(0)) * 10 - })), resB) + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2828,36 +2707,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (null, "d") ).toDF("ss", "s") - val resA = Seq( - Row("acab"), - Row(null), - Row("c"), - Row(null)) - - val resB = Seq( - Row("acab"), - Row(""), - Row("c"), - Row(null)) - def testNonPrimitiveType(): Unit = { - checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), 
resA) + checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) checkAnswer( df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"), - resB) - checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), resA) + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) + checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) checkAnswer( df.select( aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x), - acc => coalesce(acc, lit("")))), resB) - checkAnswer(df.select(aggregate(col("ss"), col("s"), new JFunc2 { - def call(acc: Column, x: Column): Column = concat(acc, x) - })), resA) - checkAnswer(df.select(aggregate(col("ss"), col("s"), new JFunc2 { - def call(acc: Column, x: Column): Column = concat(acc, x) - }, new JFunc { - def call(acc: Column): Column = coalesce(acc, lit("")) - })), resB) + acc => coalesce(acc, lit("")))), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2895,13 +2773,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3b = intercept[AnalysisException] { - df.select(aggregate(col("i"), lit(0), new JFunc2 { - def call(acc: Column, x: Column): Column = x - })) - } - assert(ex3b.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } @@ -2911,12 +2782,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(aggregate(col("s"), lit(0), (acc, x) => x)) } assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) - val ex4b = intercept[AnalysisException] { - df.select(aggregate(col("s"), lit(0), new JFunc2 { - def call(acc: Column, x: Column): Column = x - })) - } - assert(ex4b.getMessage.contains("data type mismatch: argument 3 requires int type")) val ex5 = intercept[AnalysisException] { df.selectExpr("aggregate(a, 0, (acc, x) -> x)") @@ -2932,17 +2797,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (Map(5 -> 1L), null) ).toDF("m1", "m2") - val resA = Seq( - Row(Map(8 -> true, 3 -> false, 6 -> true)), - Row(Map(10 -> null, 8 -> false, 4 -> null)), - Row(Map(5 -> null)), - Row(null)) + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) - checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), resA) - checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), resA) - checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), new JFunc3 { - def call(k: Column, v1: Column, v2: Column): Column = k === v1 + v2 - })), resA) + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) } test("map_zip_with function - map of non-primitive types") { @@ -2953,18 +2820,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { (Map("a" -> "d"), null) 
).toDF("m1", "m2") - val resA = Seq( - Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), - Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), - Row(Map("a" -> Row("d", null))), - Row(null)) + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) - checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), resA) checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), (k, v1, v2) => struct(v1, v2))), - resA) - checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), new JFunc3 { - def call(k: Column, v1: Column, v2: Column): Column = struct(v1, v2) - })), resA) + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) } test("map_zip_with function - invalid") { @@ -2989,14 +2857,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) - val ex2b = intercept[AnalysisException] { - df.select(map_zip_with(df("mis"), col("mmi"), new JFunc3 { - def call(x: Column, y: Column, z: Column): Column = concat(x, y, z) - })) - } - assert(ex2b.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with compatible key types")) - val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") } @@ -3007,13 +2867,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("type mismatch: argument 1 requires map type")) - val ex3b = intercept[AnalysisException] { - df.select(map_zip_with(df("i"), col("mmi"), new JFunc3 { - def call(x: Column, y: Column, z: Column): Column = concat(x, y, z) - })) - } - assert(ex3b.getMessage.contains("type mismatch: argument 1 requires map type")) - val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } @@ -3024,13 +2877,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex4a.getMessage.contains("type mismatch: argument 2 requires map type")) - val ex4b = intercept[AnalysisException] { - df.select(map_zip_with(df("mis"), col("i"), new JFunc3 { - def call(x: Column, y: Column, z: Column): Column = concat(x, y, z) - })) - } - assert(ex4b.getMessage.contains("type mismatch: argument 2 requires map type")) - val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } @@ -3054,27 +2900,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Map[Array[Int], Boolean](Array(1, 2) -> false) ).toDF("y") - val res1 = Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))) - - val res2 = Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))) - val res2a = Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))) - val res2b = Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))) - - val res3 = Seq(Row(Map(true -> true, true -> false))) - val res3a = Seq(Row(Map(50 -> true, 78 -> false))) - - val res4 = Seq(Row(Map(false -> false))) - def testMapOfPrimitiveTypesCombination(): Unit = { - 
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), res1) - checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), res1) - checkAnswer(dfExample1.select(transform_keys(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = k + v - })), res1) + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + + checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) checkAnswer(dfExample2.selectExpr("transform_keys(j, " + - "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), res2) + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.select( transform_keys( col("j"), @@ -3087,63 +2924,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) ) ), - res2) - checkAnswer(dfExample2.select( - transform_keys( - col("j"), - new JFunc2 { - def call(k: Column, v: Column): Column = element_at( - map_from_arrays( - array(lit(1), lit(2), lit(3)), - array(lit("one"), lit("two"), lit("three")) - ), - k - ) - } - ) - ), - res2) + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), - res2a) - checkAnswer(dfExample2.select(transform_keys(col("j"), - (k, v) => (v * 2).cast("bigint") + k)), res2a) + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(col("j"), - new JFunc2 { - def call(k: Column, v: Column): Column = - (v * 2).cast("bigint") + k - })), res2a) - - checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), res2b) - checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), res2b) - checkAnswer(dfExample2.select(transform_keys(col("j"), new JFunc2 { - def call(k: Column, v: Column): Column = k + v - })), res2b) - - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), res3) - checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), res3) - checkAnswer(dfExample3.select(transform_keys(col("x"), new JFunc2 { - def call(k: Column, v: Column): Column = k % 2 === 0 || v - })), res3) - - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), res3a) - checkAnswer(dfExample3.select(transform_keys(col("x"), - (k, v) => when(v, k * 2).otherwise(k * 3))), res3a) + (k, v) => (v * 2).cast("bigint") + k)), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + checkAnswer(dfExample3.select(transform_keys(col("x"), - new JFunc2 { - def call(k: Column, v: Column): Column = - when(v, k * 2).otherwise(k * 3) - })), res3a) + (k, v) => when(v, k * 2).otherwise(k * 3))), + 
Seq(Row(Map(50 -> true, 78 -> false)))) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), - res4) - checkAnswer(dfExample4.select(transform_keys(col("y"), - (k, v) => array_contains(k, lit(3)) && v)), res4) + Seq(Row(Map(false -> false)))) + checkAnswer(dfExample4.select(transform_keys(col("y"), - new JFunc2 { - def call(k: Column, v: Column): Column = - array_contains(k, lit(3)) && v - })), res4) + (k, v) => array_contains(k, lit(3)) && v)), + Seq(Row(Map(false -> false)))) } // Test with local relation, the Project will be evaluated without codegen @@ -3186,13 +3000,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains("Cannot use null as map key")) - val ex3b = intercept[Exception] { - dfExample1.select(transform_keys(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = v - })).show() - } - assert(ex3b.getMessage.contains("Cannot use null as map key")) - val ex4 = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") } @@ -3221,124 +3028,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Map[Int, Array[Int]](1 -> Array(1, 2)) ).toDF("c") - val res_1_1 = Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14))) + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) - val res_2_1 = Seq(Row(Map(false -> "false", true -> "def"))) - val res_2_2 = Seq(Row(Map(false -> true, true -> false))) + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), + Seq(Row(Map(false -> "false", true -> "def")))) - val res_3_1 = Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9))) - val res_3_2 = Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3"))) - val res_3_3 = Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3"))) + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + Seq(Row(Map(false -> true, true -> false)))) - val res_4_1 = Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7"))) - val res_4_2 = Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3))) + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) - val res_5_1 = Seq(Row(Map(1 -> 3))) + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) - def testMapOfPrimitiveTypesCombination(): Unit = { - checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), - res_1_1) - checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k + v)), - res_1_1) - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = k + v - })), res_1_1) + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + 
dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + Seq(Row(Map(1 -> 3)))) + + checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k + v)), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) - checkAnswer(dfExample2.selectExpr( - "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), res_2_1) checkAnswer(dfExample2.select( transform_values(col("x"), (k, v) => when(k, v).otherwise(k.cast("string")))), - res_2_1) - checkAnswer(dfExample2.select( - transform_values(col("x"), new JFunc2 { - def call(k: Column, v: Column): Column = - when(k, v).otherwise(k.cast("string")) - })), res_2_1) + Seq(Row(Map(false -> "false", true -> "def")))) - - checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), - res_2_2) checkAnswer(dfExample2.select(transform_values(col("x"), - (k, v) => (!k) && v === "abc")), res_2_2) - checkAnswer(dfExample2.select(transform_values(col("x"), - new JFunc2 { - def call(k: Column, v: Column): Column = (!k) && v === "abc" - })), res_2_2) - + (k, v) => (!k) && v === "abc")), + Seq(Row(Map(false -> true, true -> false)))) - checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), res_3_1) checkAnswer(dfExample3.select(transform_values(col("y"), (k, v) => v * v)), - res_3_1) - checkAnswer(dfExample3.select(transform_values(col("y"), new JFunc2{ - def call(k: Column, v: Column): Column = v * v - })), res_3_1) - + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) - checkAnswer(dfExample3.selectExpr( - "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), res_3_2) checkAnswer(dfExample3.select( transform_values(col("y"), (k, v) => concat(k, lit(":"), v.cast("string")))), - res_3_2) - checkAnswer(dfExample3.select( - transform_values(col("y"), new JFunc2 { - def call(k: Column, v: Column): Column = - concat(k, lit(":"), v.cast("string")) - })), res_3_2) + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) - - checkAnswer( - dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), - res_3_3) checkAnswer( dfExample3.select(transform_values(col("y"), (k, v) => concat(k, v.cast("string")))), - res_3_3) - checkAnswer( - dfExample3.select(transform_values(col("y"), new JFunc2 { - def call(k: Column, v: Column): Column = - concat(k, v.cast("string")) - })), res_3_3) - + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) - checkAnswer( - dfExample4.selectExpr( - "transform_values(" + - "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + - "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), - res_4_1) val testMap = map_from_arrays( array(lit(1), lit(2), lit(3)), array(lit("one"), lit("two"), lit("three")) ) + checkAnswer( dfExample4.select(transform_values(col("z"), (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))), - res_4_1) - checkAnswer( - dfExample4.select(transform_values(col("z"), new JFunc2 { - def call(k: Column, v: Column): Column = - concat(element_at(testMap, k), lit("_"), v.cast("string")) - })), res_4_1) + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) - checkAnswer( - dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), res_4_2) checkAnswer( dfExample4.select(transform_values(col("z"), (k, v) => k - v)), - res_4_2) - checkAnswer( - dfExample4.select(transform_values(col("z"), new JFunc2 { - def call(k: Column, v: Column): Column = k - v - })), res_4_2) + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) - checkAnswer( - dfExample5.selectExpr("transform_values(c, (k, v) 
-> k + cardinality(v))"), res_5_1) checkAnswer( dfExample5.select(transform_values(col("c"), (k, v) => k + size(v))), - res_5_1) - checkAnswer( - dfExample5.select(transform_values(col("c"), new JFunc2 { - def call(k: Column, v: Column): Column = k + size(v) - })), res_5_1) + Seq(Row(Map(1 -> 3)))) } // Test with local relation, the Project will be evaluated without codegen @@ -3404,35 +3169,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v.cast("bigint"))), Seq(Row(Map.empty[BigInt, BigInt]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), - new JFunc2 { - def call(k: Column, v: Column): Column = lit(null).cast("int")})), - Seq(Row(Map.empty[Integer, Integer]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = k})), - Seq(Row(Map.empty[Integer, Integer]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = v})), - Seq(Row(Map.empty[Integer, Integer]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = lit(0)})), - Seq(Row(Map.empty[Integer, Integer]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = lit("value")})), - Seq(Row(Map.empty[Integer, String]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = lit(true)})), - Seq(Row(Map.empty[Integer, Boolean]))) - - checkAnswer(dfExample1.select(transform_values(col("i"), new JFunc2 { - def call(k: Column, v: Column): Column = v.cast("bigint")})), - Seq(Row(Map.empty[BigInt, BigInt]))) } testEmpty() @@ -3466,19 +3202,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { transform_values(col("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2)) ), Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) - - checkAnswer(dfExample1.select(transform_values(col("a"), - new JFunc2 { - def call(k: Column, v: Column): Column = lit(null).cast("int") - })), - Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) - - checkAnswer(dfExample2.select( - transform_values(col("b"), new JFunc2 { - def call(k: Column, v: Column): Column = - when(v.isNull, k + 1).otherwise(k + 2) - })), - Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) } testNullValue() @@ -3523,13 +3246,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3a.getMessage.contains( "data type mismatch: argument 1 requires map type")) - - val ex3b = intercept[AnalysisException] { - dfExample3.select(transform_values(col("x"), new JFunc2 { - def call(k: Column, v: Column): Column = k + 1})) - } - assert(ex3b.getMessage.contains( - "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() @@ -3557,8 +3273,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(null)) checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)), expectedValue1) - checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), new JFunc2 { - def call(x: Column, y: Column): Column = x + y})), expectedValue1) val expectedValue2 = Seq( Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) @@ -3567,11 +3281,6 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df2.select(zip_with(df2("val1"), df2("val2"), (x, y) => struct(y, x))), expectedValue2 ) - checkAnswer( - df2.select(zip_with(df2("val1"), df2("val2"), new JFunc2 { - def call(x: Column, y: Column): Column = struct(y, x)})), - expectedValue2 - ) } test("arrays zip_with function - for non-primitive types") { @@ -3594,11 +3303,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(zip_with(col("val1"), col("val2"), (x, y) => struct(y, x))), expectedValue1 ) - checkAnswer( - df.select(zip_with(col("val1"), col("val2"), new JFunc2 { - def call(x: Column, y: Column): Column = struct(y, x)})), - expectedValue1 - ) } test("arrays zip_with function - invalid") { @@ -3624,11 +3328,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(zip_with(df("i"), df("a2"), (acc, x) => x)) } assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) - val ex3b = intercept[AnalysisException] { - df.select(zip_with(df("i"), df("a2"), new JFunc2 { - def call(acc: Column, x: Column): Column = x})) - } - assert(ex3b.getMessage.contains("data type mismatch: argument 1 requires array type")) val ex4 = intercept[AnalysisException] { df.selectExpr("zip_with(a1, a, (acc, x) -> x)") } From 554a99296cd78f0ab25f9a02c412d8d0a0fb2f69 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Tue, 17 Sep 2019 22:42:30 -0400 Subject: [PATCH 28/38] Add simple java test for filter --- .../sql/JavaHigherOrderFunctionsSuite.java | 190 ++++-------------- 1 file changed, 34 insertions(+), 156 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index b1cc9682d0813..11ee028e71397 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -37,20 +37,11 @@ public class JavaHigherOrderFunctionsSuite { private transient TestSparkSession spark; + private Dataset df; @Before public void setUp() { spark = new TestSparkSession(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } - - @Test - public void testTransformArrayPrimitiveNotContainingNull() { List data = toRows( makeArray(1, 9, 8, 7), makeArray(5, 8, 9, 7, 2), @@ -58,158 +49,45 @@ public void testTransformArrayPrimitiveNotContainingNull() { null ); StructType schema = new StructType() - .add("i", new ArrayType(IntegerType, true), true); - Dataset df = spark.createDataFrame(data, schema); - - Runnable f = () -> { - checkAnswer( - df.select(transform(col("i"), x -> x.plus(1))), - toRows( - makeArray(2, 10, 9, 8), - makeArray(6, 9, 10, 8, 3), - JavaTestUtils.makeArray(), - null - )); - checkAnswer( - df.select(transform(col("i"), (x, i) -> x.plus(i))), - toRows( - makeArray(1, 10, 10, 10), - makeArray(5, 9, 11, 10, 6), - JavaTestUtils.makeArray(), - null - )); - }; - - // Test with local relation, the Project will be evaluated without codegen - f.run(); - // Test with cached relation, the Project will be evaluated with codegen - df.cache(); - f.run(); - } - - @Test - public void testTransformArrayPrimitiveContainingNull() { - List data = toRows( - makeArray(1, 9, 8, null, 7), - makeArray(5, null, 8, 9, 7, 2), - JavaTestUtils.makeArray(), - null - ); - StructType schema = new StructType() - .add("i", new ArrayType(IntegerType, true), 
true); - Dataset df = spark.createDataFrame(data, schema); - - Runnable f = () -> { - checkAnswer( - df.select(transform(col("i"), x -> x.plus(1))), - toRows( - makeArray(2, 10, 9, null, 8), - makeArray(6, null, 9, 10, 8, 3), - JavaTestUtils.makeArray(), - null - )); - checkAnswer( - df.select(transform(col("i"), (x, i) -> x.plus(i))), - toRows( - makeArray(1, 10, 10, null, 11), - makeArray(5, null, 10, 12, 11, 7), - JavaTestUtils.makeArray(), - null - )); - }; + .add("x", new ArrayType(IntegerType, true), true); + df = spark.createDataFrame(data, schema); + } - // Test with local relation, the Project will be evaluated without codegen - f.run(); - df.cache(); - // Test with cached relation, the Project will be evaluated with codegen - f.run(); + @After + public void tearDown() { + spark.stop(); + spark = null; } @Test - public void testTransformArrayNonPrimitive() { - List data = toRows( - makeArray("c", "a", "b"), - makeArray("b", null, "c", null), - JavaTestUtils.makeArray(), - null - ); - StructType schema = new StructType() - .add("s", new ArrayType(StringType, true), true); - Dataset df = spark.createDataFrame(data, schema); - - Runnable f = () -> { - checkAnswer(df.select(transform(col("s"), x -> concat(x, x))), - toRows( - makeArray("cc", "aa", "bb"), - makeArray("bb", null, "cc", null), - JavaTestUtils.makeArray(), - null - )); - checkAnswer(df.select(transform(col("s"), (x, i) -> concat(x, i))), - toRows( - makeArray("c0", "a1", "b2"), - makeArray("b0", null, "c2", null), - JavaTestUtils.makeArray(), - null - )); - }; - - // Test with local relation, the Project will be evaluated without codegen - f.run(); - // Test with cached relation, the Project will be evaluated with codegen - df.cache(); - f.run(); + public void testTransform() { + checkAnswer( + df.select(transform(col("x"), x -> x.plus(1))), + toRows( + makeArray(2, 10, 9, 8), + makeArray(6, 9, 10, 8, 3), + JavaTestUtils.makeArray(), + null + )); + checkAnswer( + df.select(transform(col("x"), (x, i) -> x.plus(i))), + toRows( + makeArray(1, 10, 10, 10), + makeArray(5, 9, 11, 10, 6), + JavaTestUtils.makeArray(), + null + )); } @Test - public void testTransformSpecialCases() { - List data = toRows( - makeArray("c", "a", "b"), - makeArray("b", null, "c", null), - JavaTestUtils.makeArray(), - null - ); - StructType schema = new StructType() - .add("s", new ArrayType(StringType, true), true); - Dataset df = spark.createDataFrame(data, schema); - - Runnable f = () -> { - checkAnswer(df.select(transform(col("arg"), arg -> arg)), - toRows( - makeArray("c", "a", "b"), - makeArray("b", null, "c", null), - JavaTestUtils.makeArray(), - null)); - checkAnswer(df.select(transform(col("arg"), x -> col("arg"))), - toRows( - makeArray( - makeArray("c", "a", "b"), - makeArray("c", "a", "b"), - makeArray("c", "a", "b") - ), - makeArray( - makeArray("b", null, "c", null), - makeArray("b", null, "c", null), - makeArray("b", null, "c", null), - makeArray("b", null, "c", null) - ), - JavaTestUtils.makeArray(), - null)); - checkAnswer(df.select(transform(col("arg"), x -> concat(col("arg"), array(x)))), - toRows( - makeArray( - makeArray("c", "a", "b", "c"), - makeArray("c", "a", "b", "c"), - makeArray("c", "a", "b", "c") - ), - makeArray( - makeArray("b", null, "c", null, "b"), - makeArray("b", null, "c", null, null), - makeArray("b", null, "c", null, "b"), - makeArray("b", null, "c", null, null) - ), - JavaTestUtils.makeArray(), - null)); - }; + public void testFilter() { + checkAnswer( + df.select(filter(col("x"), x -> x.plus(1).equalTo(10))), + 
toRows( + makeArray(9), + makeArray(9), + JavaTestUtils.makeArray(), + null + )); } } From 0433756af77b536bceb3e588b1e27888834d6953 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Tue, 17 Sep 2019 23:06:37 -0400 Subject: [PATCH 29/38] Add simple java test for exists --- .../spark/sql/JavaHigherOrderFunctionsSuite.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index 11ee028e71397..0d3d372afa15c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -90,4 +90,16 @@ public void testFilter() { null )); } + + @Test + public void testExists() { + checkAnswer( + df.select(exists(col("x"), x -> x.plus(1).equalTo(10))), + toRows( + true, + true, + false, + null + )); + } } From f371413de47c33beb78c396f0da36bea5cdc62a0 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Tue, 17 Sep 2019 23:10:54 -0400 Subject: [PATCH 30/38] Add simple java test for forall --- .../spark/sql/JavaHigherOrderFunctionsSuite.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index 0d3d372afa15c..264a796159519 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -102,4 +102,16 @@ public void testExists() { null )); } + + @Test + public void testForall() { + checkAnswer( + df.select(forall(col("x"), x -> x.plus(1).equalTo(10))), + toRows( + false, + false, + true, + null + )); + } } From c3e320c397a7c5779ff999d05f707ac02a56037c Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Wed, 18 Sep 2019 22:47:12 -0400 Subject: [PATCH 31/38] Add java test for aggregate --- .../spark/sql/JavaHigherOrderFunctionsSuite.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index 264a796159519..381cb41d36db1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -114,4 +114,16 @@ public void testForall() { null )); } + + @Test + public void testAggregate() { + checkAnswer( + df.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))), + toRows( + 25, + 31, + 0, + null + )); + } } From 84ccf55354cc8d6861132f0c9c0b224e84703fa0 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Wed, 18 Sep 2019 22:48:30 -0400 Subject: [PATCH 32/38] Add java aggregate test with finish --- .../apache/spark/sql/JavaHigherOrderFunctionsSuite.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index 381cb41d36db1..0dc40da5ba0e6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -125,5 
+125,13 @@ public void testAggregate() { 0, null )); + checkAnswer( + df.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)), + toRows( + 25, + 31, + 0, + null + )); } } From e43033b877e3af720c208e31c1c9e83780d425aa Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Wed, 18 Sep 2019 22:51:58 -0400 Subject: [PATCH 33/38] Add java test for zip_with --- .../spark/sql/JavaHigherOrderFunctionsSuite.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index 0dc40da5ba0e6..e3c3ada79e156 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -134,4 +134,16 @@ public void testAggregate() { null )); } + + @Test + public void testZipWith() { + checkAnswer( + df.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))), + toRows( + makeArray(42, 42, 42, 42), + makeArray(42, 42, 42, 42, 42), + JavaTestUtils.makeArray(), + null + )); + } } From c1c76a9fda852b0890bdf40197b0f7483c6e3bb0 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 19 Sep 2019 21:14:35 -0400 Subject: [PATCH 34/38] Add java test for transformKeys --- .../sql/JavaHigherOrderFunctionsSuite.java | 61 +++++++++++++++---- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index e3c3ada79e156..27c89c52db4c8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -17,9 +17,11 @@ package test.org.apache.spark.sql; +import java.util.HashMap; import java.util.List; import scala.collection.Seq; +import static scala.collection.JavaConverters.mapAsScalaMap; import org.junit.After; import org.junit.Assert; @@ -37,11 +39,10 @@ public class JavaHigherOrderFunctionsSuite { private transient TestSparkSession spark; - private Dataset df; + private Dataset arrDf; + private Dataset mapDf; - @Before - public void setUp() { - spark = new TestSparkSession(); + private void setUpArrDf() { List data = toRows( makeArray(1, 9, 8, 7), makeArray(5, 8, 9, 7, 2), @@ -50,7 +51,27 @@ public void setUp() { ); StructType schema = new StructType() .add("x", new ArrayType(IntegerType, true), true); - df = spark.createDataFrame(data, schema); + arrDf = spark.createDataFrame(data, schema); + } + + private void setUpMapDf() { + List data = toRows( + new HashMap() {{ + put(1, 1); + put(2, 2); + }}, + null + ); + StructType schema = new StructType() + .add("x", new MapType(IntegerType, IntegerType, true)); + mapDf = spark.createDataFrame(data, schema); + } + + @Before + public void setUp() { + spark = new TestSparkSession(); + setUpArrDf(); + setUpMapDf(); } @After @@ -62,7 +83,7 @@ public void tearDown() { @Test public void testTransform() { checkAnswer( - df.select(transform(col("x"), x -> x.plus(1))), + arrDf.select(transform(col("x"), x -> x.plus(1))), toRows( makeArray(2, 10, 9, 8), makeArray(6, 9, 10, 8, 3), @@ -70,7 +91,7 @@ public void testTransform() { null )); checkAnswer( - df.select(transform(col("x"), (x, i) -> x.plus(i))), + arrDf.select(transform(col("x"), (x, i) -> x.plus(i))), toRows( makeArray(1, 10, 10, 10), 
makeArray(5, 9, 11, 10, 6), @@ -82,7 +103,7 @@ public void testTransform() { @Test public void testFilter() { checkAnswer( - df.select(filter(col("x"), x -> x.plus(1).equalTo(10))), + arrDf.select(filter(col("x"), x -> x.plus(1).equalTo(10))), toRows( makeArray(9), makeArray(9), @@ -94,7 +115,7 @@ public void testFilter() { @Test public void testExists() { checkAnswer( - df.select(exists(col("x"), x -> x.plus(1).equalTo(10))), + arrDf.select(exists(col("x"), x -> x.plus(1).equalTo(10))), toRows( true, true, @@ -106,7 +127,7 @@ public void testExists() { @Test public void testForall() { checkAnswer( - df.select(forall(col("x"), x -> x.plus(1).equalTo(10))), + arrDf.select(forall(col("x"), x -> x.plus(1).equalTo(10))), toRows( false, false, @@ -118,7 +139,7 @@ public void testForall() { @Test public void testAggregate() { checkAnswer( - df.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))), + arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))), toRows( 25, 31, @@ -126,7 +147,7 @@ public void testAggregate() { null )); checkAnswer( - df.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)), + arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)), toRows( 25, 31, @@ -138,7 +159,7 @@ public void testAggregate() { @Test public void testZipWith() { checkAnswer( - df.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))), + arrDf.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))), toRows( makeArray(42, 42, 42, 42), makeArray(42, 42, 42, 42, 42), @@ -146,4 +167,18 @@ public void testZipWith() { null )); } + + @Test + public void testTransformKeys() { + checkAnswer( + mapDf.select(transform_keys(col("x"), (k, v) -> k.plus(v))), + toRows( + mapAsScalaMap( + new HashMap() {{ + put(2, 1); + put(4, 2); + }}), + null + )); + } } From 10a5f2e1fdaf1091db28e0becd0a202861768615 Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 19 Sep 2019 21:16:17 -0400 Subject: [PATCH 35/38] Add java test for transform_values --- .../spark/sql/JavaHigherOrderFunctionsSuite.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index 27c89c52db4c8..aed4699521517 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -181,4 +181,18 @@ public void testTransformKeys() { null )); } + + @Test + public void testTransformValues() { + checkAnswer( + mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))), + toRows( + mapAsScalaMap( + new HashMap() {{ + put(1, 2); + put(2, 4); + }}), + null + )); + } } From 722f0e68a9a4a73118ea636761f8e888a2dc48fd Mon Sep 17 00:00:00 2001 From: Nik Vanderhoof Date: Thu, 19 Sep 2019 21:20:09 -0400 Subject: [PATCH 36/38] Add java test for map_filter and map_zip_with --- .../sql/JavaHigherOrderFunctionsSuite.java | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index aed4699521517..94f6f747e197e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -195,4 +195,27 @@ public void 
From 722f0e68a9a4a73118ea636761f8e888a2dc48fd Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Thu, 19 Sep 2019 21:20:09 -0400
Subject: [PATCH 36/38] Add java test for map_filter and map_zip_with

---
 .../sql/JavaHigherOrderFunctionsSuite.java | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
index aed4699521517..94f6f747e197e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
@@ -195,4 +195,27 @@ public void testTransformValues() {
             null
         ));
     }
+
+    @Test
+    public void testMapFilter() {
+        checkAnswer(
+            mapDf.select(map_filter(col("x"), (k, v) -> lit(false))),
+            toRows(
+                mapAsScalaMap(new HashMap<Integer, Integer>()),
+                null
+            ));
+    }
+
+    @Test
+    public void testMapZipWith() {
+        checkAnswer(
+            mapDf.select(map_zip_with(col("x"), col("x"), (k, v1, v2) -> lit(false))),
+            toRows(
+                mapAsScalaMap(new HashMap<Integer, Boolean>() {{
+                    put(1, false);
+                    put(2, false);
+                }}),
+                null
+            ));
+    }
 }
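
For reference: map_filter keeps only the entries whose predicate lambda returns true, and map_zip_with aligns two maps on their keys and combines the two values per key, with a null standing in for any side that lacks the key. A minimal Scala sketch (df, m1, and m2 are illustrative placeholders):

    import org.apache.spark.sql.functions._

    // Keep only entries with a positive value.
    val cleaned = df.select(map_filter(col("m1"), (k, v) => v > lit(0)))

    // Merge two maps, preferring m1's value and falling back to m2's.
    val merged = df.select(
      map_zip_with(col("m1"), col("m2"), (k, v1, v2) => coalesce(v1, v2)))
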
From 1bf2654e759eef8f48aaeedcaaa80c01b1c97c94 Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Tue, 1 Oct 2019 22:32:35 -0400
Subject: [PATCH 37/38] Fix style nits

---
 .../org/apache/spark/sql/functions.scala | 13 ++++--
 .../sql/JavaHigherOrderFunctionsSuite.java | 42 ++++++++++++-------
 2 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9413f51aeaa9f..a8d6964b3b83e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3466,8 +3466,11 @@ object functions {
    * @group collection_funcs
    * @since 3.0.0
    */
-  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column,
-      finish: Column => Column): Column = withExpr {
+  def aggregate(
+      expr: Column,
+      zero: Column,
+      merge: (Column, Column) => Column,
+      finish: Column => Column): Column = withExpr {
     ArrayAggregate(
       expr.expr,
       zero.expr,
@@ -3536,8 +3539,10 @@ object functions {
    * @group collection_funcs
    * @since 3.0.0
    */
-  def map_zip_with(left: Column, right: Column,
-      f: (Column, Column, Column) => Column): Column = withExpr {
+  def map_zip_with(
+      left: Column,
+      right: Column,
+      f: (Column, Column, Column) => Column): Column = withExpr {
     MapZipWith(left.expr, right.expr, createLambda(f))
   }
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
index 94f6f747e197e..db59b7e98e125 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
@@ -89,7 +89,8 @@ public void testTransform() {
                 makeArray(6, 9, 10, 8, 3),
                 JavaTestUtils.makeArray(),
                 null
-            ));
+            )
+        );
         checkAnswer(
             arrDf.select(transform(col("x"), (x, i) -> x.plus(i))),
             toRows(
@@ -97,7 +98,8 @@ public void testTransform() {
                 makeArray(5, 9, 11, 10, 6),
                 JavaTestUtils.makeArray(),
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -109,7 +111,8 @@ public void testFilter() {
                 makeArray(9),
                 JavaTestUtils.makeArray(),
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -121,7 +124,8 @@ public void testExists() {
                 true,
                 false,
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -133,7 +137,8 @@ public void testForall() {
                 false,
                 true,
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -145,7 +150,8 @@ public void testAggregate() {
                 31,
                 0,
                 null
-            ));
+            )
+        );
         checkAnswer(
             arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)),
             toRows(
@@ -153,7 +159,8 @@ public void testAggregate() {
                 31,
                 0,
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -165,7 +172,8 @@ public void testZipWith() {
                 makeArray(42, 42, 42, 42, 42),
                 JavaTestUtils.makeArray(),
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -173,13 +181,13 @@ public void testTransformKeys() {
         checkAnswer(
             mapDf.select(transform_keys(col("x"), (k, v) -> k.plus(v))),
             toRows(
-                mapAsScalaMap(
-                    new HashMap<Integer, Integer>() {{
+                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                     put(2, 1);
                     put(4, 2);
                 }}),
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -187,13 +195,13 @@ public void testTransformValues() {
         checkAnswer(
             mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
             toRows(
-                mapAsScalaMap(
-                    new HashMap<Integer, Integer>() {{
+                mapAsScalaMap(new HashMap<Integer, Integer>() {{
                     put(1, 2);
                     put(2, 4);
                 }}),
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -203,7 +211,8 @@ public void testMapFilter() {
             toRows(
                 mapAsScalaMap(new HashMap<Integer, Integer>()),
                 null
-            ));
+            )
+        );
     }
 
     @Test
@@ -216,6 +225,7 @@ public void testMapZipWith() {
                     put(2, false);
                 }}),
                 null
-            ));
+            )
+        );
     }
 }

From 64c0f87a8005a27458394a14648c0e75ee514678 Mon Sep 17 00:00:00 2001
From: Nik Vanderhoof
Date: Tue, 1 Oct 2019 23:38:06 -0400
Subject: [PATCH 38/38] Fix linter errors in imports

---
 .../org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
index db59b7e98e125..a5f11d57f3ce6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
@@ -20,11 +20,9 @@
 
 import java.util.HashMap;
 import java.util.List;
 
-import scala.collection.Seq;
 import static scala.collection.JavaConverters.mapAsScalaMap;
 
 import org.junit.After;
-import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -35,7 +33,6 @@
 import static org.apache.spark.sql.functions.*;
 import org.apache.spark.sql.test.TestSparkSession;
 import static test.org.apache.spark.sql.JavaTestUtils.*;
-import test.org.apache.spark.sql.JavaTestUtils;
 
 public class JavaHigherOrderFunctionsSuite {
     private transient TestSparkSession spark;
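
With the style and import cleanups in place, the Java suite exercises every higher-order function the series adds to the Scala API: transform, filter, exists, forall, aggregate, zip_with, transform_keys, transform_values, map_filter, and map_zip_with. As a closing sketch, the four-argument aggregate with a finish lambda that the suite tests above, here used to compute a mean (spark and the example data are illustrative, not from the patch):

    import org.apache.spark.sql.functions._

    val df = spark.range(1).selectExpr("array(1, 9, 8, 7) as xs")

    // Fold the array into a sum, then divide by the array length in the
    // finish step; the cast avoids integer division.
    df.select(aggregate(col("xs"), lit(0), (acc, x) => acc + x,
      acc => acc.cast("double") / size(col("xs"))))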