From ae7092d84cff45f41e7d18ae34d8054f502a5205 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Thu, 21 Oct 2021 16:05:21 +0800 Subject: [PATCH 1/8] Spark SQL should support create function with Aggregator --- .../catalog/FunctionExpressionBuilder.scala | 2 + .../internal/BaseSessionStateBuilder.scala | 38 ++++++++++++++++++- .../test/resources/sql-tests/inputs/udaf.sql | 6 +++ .../resources/sql-tests/results/udaf.sql.out | 26 ++++++++++++- .../scala/org/apache/spark/sql/UDFSuite.scala | 35 +++++++++++++++++ 5 files changed, 104 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala index bf3d790b86c03..256486658efcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions.Expression // A builder to create `Expression` from function information. trait FunctionExpressionBuilder { + // `name` and `clazz` are the name and provided class of user-defined functions, respectively. + // `input` is the children of `ScalaUDAF` or `ScalaAggregator`. def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 1e60cb8b1db2a..f6d028a13a0e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -16,10 +16,15 @@ */ package org.apache.spark.sql.internal +import scala.reflect.ClassTag +import scala.reflect.runtime.universe + import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -28,13 +33,13 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} -import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF} +import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog} import org.apache.spark.sql.execution.streaming.ResolveWriteToStream -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction} import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -411,6 +416,35 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder { name, expr.inputTypes.size.toString, input.size) } expr + } else if (classOf[Aggregator[_, _, _]].isAssignableFrom(clazz)) { + val aggregator = clazz.getConstructor().newInstance().asInstanceOf[Aggregator[Any, Any, Any]] + + // Construct the input encoder + val mirror = universe.runtimeMirror(clazz.getClassLoader) + val classType = mirror.classSymbol(clazz) + val baseClassType = universe.typeOf[Aggregator[_, _, _]].typeSymbol.asClass + val baseType = universe.internal.thisType(classType).baseType(baseClassType) + val tpe = baseType.typeArgs.head + val cls = mirror.runtimeClass(tpe) + val serializer = ScalaReflection.serializerForType(tpe) + val deserializer = ScalaReflection.deserializerForType(tpe) + val inputEncoder = new ExpressionEncoder[Any]( + serializer, + deserializer, + ClassTag(cls)) + + val expr = ScalaAggregator[Any, Any, Any]( + input, + aggregator, + inputEncoder, + aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[Any]], + aggregatorName = Some(name)) + // Check input argument size + if (expr.inputTypes.size != input.size) { + throw QueryCompilationErrors.invalidFunctionArgumentsError( + name, expr.inputTypes.size.toString, input.size) + } + expr } else { throw QueryCompilationErrors.noHandlerForUDAFError(clazz.getCanonicalName) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql index 0374d98feb6e6..3fc65f46a17a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -19,3 +19,9 @@ SELECT default.udaf1(int_col1) as udaf1 from t1; DROP FUNCTION myDoubleAvg; DROP FUNCTION udaf1; + +CREATE FUNCTION myDoubleAverage AS 'test.org.apache.spark.sql.MyDoubleAverage'; + +SELECT default.myDoubleAverage(int_col1) as my_avg from t1; + +DROP FUNCTION myDoubleAverage; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 9f4229a11b65d..992900b7dde4f 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 11 -- !query @@ -68,3 +68,27 @@ DROP FUNCTION udaf1 struct<> -- !query output + + +-- !query +CREATE FUNCTION myDoubleAverage AS 'test.org.apache.spark.sql.MyDoubleAverage' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT default.myDoubleAverage(int_col1) as my_avg from t1 +-- !query schema +struct +-- !query output +2.5 + + +-- !query +DROP FUNCTION myDoubleAverage +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d100cad89fcc1..0006961cd6c79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.lang.{Double => jlDouble} import java.math.BigDecimal import java.sql.Timestamp import java.time.{Instant, LocalDate} @@ -50,6 +51,20 @@ private case class FunctionResult(f1: String, f2: String) private case class LocalDateInstantType(date: LocalDate, instant: Instant) private case class TimestampInstantType(t: Timestamp, instant: Instant) +class MyDoubleAverage extends Aggregator[jlDouble, (Double, Long), jlDouble] { + def zero: (Double, Long) = (0.0, 0L) + def reduce(b: (Double, Long), a: jlDouble): (Double, Long) = { + if (a != null) (b._1 + a, b._2 + 1L) else b + } + def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = + (b1._1 + b2._1, b1._2 + b2._2) + def finish(r: (Double, Long)): jlDouble = + if (r._2 > 0L) 100.0 + (r._1 / r._2.toDouble) else null + def bufferEncoder: Encoder[(Double, Long)] = + Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong) + def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE +} + class UDFSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -848,6 +863,26 @@ class UDFSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-37018: Spark SQL should support create function with Aggregator") { + val avgFuncClass = "org.apache.spark.sql.MyDoubleAverage" + val functionName = "test_udf" + withTempDatabase { dbName => + withUserDefinedFunction( + s"default.$functionName" -> false, + s"$dbName.$functionName" -> false, + functionName -> true) { + // create a function in default database + sql("USE DEFAULT") + sql(s"CREATE FUNCTION $functionName AS '$avgFuncClass'") + // create a view using a function in 'default' database + withView("v1") { + sql(s"CREATE VIEW v1 AS SELECT $functionName(col1) AS func FROM VALUES (1), (2), (3)") + checkAnswer(sql("SELECT * FROM v1"), Seq(Row(102.0))) + } + } + } + } + test("SPARK-35674: using java.time.LocalDateTime in UDF") { // Regular case val input = Seq(java.time.LocalDateTime.parse("2021-01-01T00:00:00")).toDF("dateTime") From 0a3ccd8d54cf696374e74d3a86a83acfa1eac872 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 18 Nov 2021 09:03:05 +0800 Subject: [PATCH 2/8] Update sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala Co-authored-by: Wenchen Fan --- .../apache/spark/sql/internal/BaseSessionStateBuilder.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index f6d028a13a0e8..34a9ac3d62270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -428,10 +428,7 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder { val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) - val inputEncoder = new ExpressionEncoder[Any]( - serializer, - deserializer, - ClassTag(cls)) + val inputEncoder = new ExpressionEncoder[Any](serializer, deserializer, ClassTag(cls)) val expr = ScalaAggregator[Any, Any, Any]( input, From 1fc88ffeec8d4d6acc2cf0efabddb3f61818adce Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 18 Nov 2021 16:14:23 +0800 Subject: [PATCH 3/8] Update code --- sql/core/src/test/resources/sql-tests/inputs/udaf.sql | 2 +- sql/core/src/test/resources/sql-tests/results/udaf.sql.out | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql index 3fc65f46a17a2..bac35326df619 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -20,7 +20,7 @@ SELECT default.udaf1(int_col1) as udaf1 from t1; DROP FUNCTION myDoubleAvg; DROP FUNCTION udaf1; -CREATE FUNCTION myDoubleAverage AS 'test.org.apache.spark.sql.MyDoubleAverage'; +CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage'; SELECT default.myDoubleAverage(int_col1) as my_avg from t1; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 992900b7dde4f..f12224c9b3b08 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -71,7 +71,7 @@ struct<> -- !query -CREATE FUNCTION myDoubleAverage AS 'test.org.apache.spark.sql.MyDoubleAverage' +CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage' -- !query schema struct<> -- !query output @@ -83,7 +83,7 @@ SELECT default.myDoubleAverage(int_col1) as my_avg from t1 -- !query schema struct -- !query output -2.5 +102.5 -- !query From aba8c90be094155400f5614eeaecc41402540555 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 18 Nov 2021 18:56:19 +0800 Subject: [PATCH 4/8] Update code --- .../sql/errors/QueryExecutionErrors.scala | 5 ++ .../internal/BaseSessionStateBuilder.scala | 11 ++- .../test/resources/sql-tests/inputs/udaf.sql | 19 +++-- .../resources/sql-tests/results/udaf.sql.out | 71 +++++++++++++++++-- .../scala/org/apache/spark/sql/UDFSuite.scala | 4 ++ 5 files changed, 97 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 6f0ed23e10228..c173a33f1a951 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1920,4 +1920,9 @@ object QueryExecutionErrors { s". To solve this try to set $maxDynamicPartitionsKey" + s" to at least $numWrittenParts.") } + + def registerFunctionWithoutParameterlessConstructorError(className: String): Throwable = { + new RuntimeException(s"Register aggregate function with '$className' which not provides " + + "parameterless constructor is not supported") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 34a9ac3d62270..7a86e92fb154e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.internal import scala.reflect.ClassTag import scala.reflect.runtime.universe - import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.ScalaReflection @@ -31,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin @@ -417,7 +416,13 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder { } expr } else if (classOf[Aggregator[_, _, _]].isAssignableFrom(clazz)) { - val aggregator = clazz.getConstructor().newInstance().asInstanceOf[Aggregator[Any, Any, Any]] + val noParameterConstructor = clazz.getConstructors.find(_.getParameterCount == 0) + if (noParameterConstructor.isEmpty) { + throw QueryExecutionErrors.registerFunctionWithoutParameterlessConstructorError( + clazz.getCanonicalName) + } + val aggregator = + noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Any, Any, Any]] // Construct the input encoder val mirror = universe.runtimeMirror(clazz.getClassLoader) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql index bac35326df619..5d5e0dddaa17a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -17,11 +17,22 @@ CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; SELECT default.udaf1(int_col1) as udaf1 from t1; -DROP FUNCTION myDoubleAvg; -DROP FUNCTION udaf1; - CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage'; SELECT default.myDoubleAverage(int_col1) as my_avg from t1; -DROP FUNCTION myDoubleAverage; \ No newline at end of file +SELECT default.myDoubleAverage(int_col1, 3) as my_avg from t1; + +CREATE FUNCTION myDoubleAverage2 AS 'test.org.apache.spark.sql.MyDoubleAverage'; + +SELECT default.myDoubleAverage2(int_col1) as my_avg from t1; + +CREATE FUNCTION MyDoubleSum AS 'org.apache.spark.sql.MyDoubleSum'; + +SELECT default.MyDoubleSum(int_col1) as my_sum from t1; + +DROP FUNCTION myDoubleAvg; +DROP FUNCTION udaf1; +DROP FUNCTION myDoubleAverage; +DROP FUNCTION myDoubleAverage2; +DROP FUNCTION MyDoubleSum; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index f12224c9b3b08..b36f3949a61fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 18 -- !query @@ -54,6 +54,65 @@ org.apache.spark.sql.AnalysisException Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 +-- !query +CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT default.myDoubleAverage(int_col1) as my_avg from t1 +-- !query schema +struct +-- !query output +102.5 + + +-- !query +SELECT default.myDoubleAverage(int_col1, 3) as my_avg from t1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function default.myDoubleAverage. Expected: 1; Found: 2; line 1 pos 7 + + +-- !query +CREATE FUNCTION myDoubleAverage2 AS 'test.org.apache.spark.sql.MyDoubleAverage' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT default.myDoubleAverage2(int_col1) as my_avg from t1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Can not load class 'test.org.apache.spark.sql.MyDoubleAverage' when registering the function 'default.myDoubleAverage2', please make sure it is on the classpath; line 1 pos 7 + + +-- !query +CREATE FUNCTION MyDoubleSum AS 'org.apache.spark.sql.MyDoubleSum' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT default.MyDoubleSum(int_col1) as my_sum from t1 +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +Register aggregate function with 'org.apache.spark.sql.MyDoubleSum' which not provides parameterless constructor is not supported + + -- !query DROP FUNCTION myDoubleAvg -- !query schema @@ -71,7 +130,7 @@ struct<> -- !query -CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage' +DROP FUNCTION myDoubleAverage -- !query schema struct<> -- !query output @@ -79,15 +138,15 @@ struct<> -- !query -SELECT default.myDoubleAverage(int_col1) as my_avg from t1 +DROP FUNCTION myDoubleAverage2 -- !query schema -struct +struct<> -- !query output -102.5 + -- !query -DROP FUNCTION myDoubleAverage +DROP FUNCTION MyDoubleSum -- !query schema struct<> -- !query output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 0006961cd6c79..fc02848bb854b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -65,6 +65,10 @@ class MyDoubleAverage extends Aggregator[jlDouble, (Double, Long), jlDouble] { def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE } +class MyDoubleSum(test: Boolean) extends MyDoubleAverage { + override def finish(r: (Double, Long)): jlDouble = if (r._2 > 0L) r._1 else null +} + class UDFSuite extends QueryTest with SharedSparkSession { import testImplicits._ From 6ba32b95aea04d2504b2b47e52f68952bd756ec7 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 18 Nov 2021 19:16:27 +0800 Subject: [PATCH 5/8] Update code --- .../scala/org/apache/spark/sql/UDFSuite.scala | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index fc02848bb854b..e0e4dc51cfd4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -869,19 +869,30 @@ class UDFSuite extends QueryTest with SharedSparkSession { test("SPARK-37018: Spark SQL should support create function with Aggregator") { val avgFuncClass = "org.apache.spark.sql.MyDoubleAverage" - val functionName = "test_udf" + val avgFunction = "my_avg" + val sumFuncClass = "org.apache.spark.sql.MyDoubleSum" + val sumFunction = "my_sum" withTempDatabase { dbName => withUserDefinedFunction( - s"default.$functionName" -> false, - s"$dbName.$functionName" -> false, - functionName -> true) { + s"default.$avgFunction" -> false, + s"default.$sumFunction" -> false, + s"$dbName.$avgFunction" -> false, + s"$dbName.$sumFunction" -> false, + avgFunction -> true, + sumFunction -> true) { // create a function in default database sql("USE DEFAULT") - sql(s"CREATE FUNCTION $functionName AS '$avgFuncClass'") + sql(s"CREATE FUNCTION $avgFunction AS '$avgFuncClass'") + sql(s"CREATE FUNCTION $sumFunction AS '$sumFuncClass'") // create a view using a function in 'default' database withView("v1") { - sql(s"CREATE VIEW v1 AS SELECT $functionName(col1) AS func FROM VALUES (1), (2), (3)") + sql(s"CREATE VIEW v1 AS SELECT $avgFunction(col1) AS func FROM VALUES (1), (2), (3)") checkAnswer(sql("SELECT * FROM v1"), Seq(Row(102.0))) + + val e = intercept[RuntimeException] { + sql(s"CREATE VIEW v2 AS SELECT $sumFunction(col1) AS func FROM VALUES (1), (2), (3)") + } + assert(e.getMessage.contains("not provides parameterless constructor is not supported")) } } } From da3ed2546c2491f8c2b90c1dd9d185f81a67d0f4 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 18 Nov 2021 19:20:33 +0800 Subject: [PATCH 6/8] Update code --- sql/core/src/test/resources/sql-tests/inputs/udaf.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql index 5d5e0dddaa17a..8e8eb2e8793fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -35,4 +35,4 @@ DROP FUNCTION myDoubleAvg; DROP FUNCTION udaf1; DROP FUNCTION myDoubleAverage; DROP FUNCTION myDoubleAverage2; -DROP FUNCTION MyDoubleSum; \ No newline at end of file +DROP FUNCTION MyDoubleSum; From e7425524f586a54585b7a986bef660e7303dfaae Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 18 Nov 2021 21:11:57 +0800 Subject: [PATCH 7/8] Update code --- .../org/apache/spark/sql/internal/BaseSessionStateBuilder.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 7a86e92fb154e..dbffdaf18b27b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import scala.reflect.ClassTag import scala.reflect.runtime.universe + import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.ScalaReflection From f9a50e2a7dd17bbdfee5b3a0eea9680ca87977e3 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 23 Dec 2021 21:00:04 +0800 Subject: [PATCH 8/8] Update code --- .../internal/BaseSessionStateBuilder.scala | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index dbffdaf18b27b..e28adc520db1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.internal +import java.io.Serializable + import scala.reflect.ClassTag import scala.reflect.runtime.universe @@ -33,13 +35,14 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} -import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaAggregator, ScalaUDAF} +import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog} import org.apache.spark.sql.execution.streaming.ResolveWriteToStream -import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction} +import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} +import org.apache.spark.sql.functions.udaf import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -423,25 +426,25 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder { clazz.getCanonicalName) } val aggregator = - noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Any, Any, Any]] + noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Serializable, Any, Any]] // Construct the input encoder val mirror = universe.runtimeMirror(clazz.getClassLoader) val classType = mirror.classSymbol(clazz) - val baseClassType = universe.typeOf[Aggregator[_, _, _]].typeSymbol.asClass + val baseClassType = universe.typeOf[Aggregator[Serializable, Any, Any]].typeSymbol.asClass val baseType = universe.internal.thisType(classType).baseType(baseClassType) val tpe = baseType.typeArgs.head - val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) - val inputEncoder = new ExpressionEncoder[Any](serializer, deserializer, ClassTag(cls)) + val cls = mirror.runtimeClass(tpe) + val inputEncoder = + new ExpressionEncoder[Serializable](serializer, deserializer, ClassTag(cls)) - val expr = ScalaAggregator[Any, Any, Any]( - input, - aggregator, - inputEncoder, - aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[Any]], - aggregatorName = Some(name)) + val udf: UserDefinedFunction = udaf[Serializable, Any, Any](aggregator, inputEncoder) + assert(udf.isInstanceOf[UserDefinedAggregator[_, _, _]]) + val udfAgg: UserDefinedAggregator[_, _, _] = udf.asInstanceOf[UserDefinedAggregator[_, _, _]] + + val expr = udfAgg.scalaAggregator(input) // Check input argument size if (expr.inputTypes.size != input.size) { throw QueryCompilationErrors.invalidFunctionArgumentsError(