diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
index 0e28a939e3046..fecde71d56ab3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
@@ -38,7 +38,13 @@ public Count(FieldReference column, boolean isDistinct) {
   public boolean isDistinct() { return isDistinct; }
 
   @Override
-  public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+  public String toString() {
+    if (isDistinct) {
+      return "COUNT(DISTINCT " + column.describe() + ")";
+    } else {
+      return "COUNT(" + column.describe() + ")";
+    }
+  }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
index 21a3564480a64..8e799cd23e57f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
@@ -31,7 +31,7 @@ public CountStar() {
   }
 
   @Override
-  public String toString() { return "CountStar()"; }
+  public String toString() { return "COUNT(*)"; }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
index d2ff6b2f04d6c..3ce45cae919b0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
@@ -33,7 +33,7 @@ public final class Max implements AggregateFunc {
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() { return "Max(" + column.describe() + ")"; }
+  public String toString() { return "MAX(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
index efa80361000b0..2449358f7cac8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
@@ -33,7 +33,7 @@ public final class Min implements AggregateFunc {
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() { return "Min(" + column.describe() + ")"; }
+  public String toString() { return "MIN(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
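These `toString()` implementations now render as SQL snippets instead of debug strings, which is what surfaces in the `PushedAggregates` explain output (see the test changes at the end of this diff). A minimal sketch of the resulting strings, using the same `FieldReference(...)` construction the patch itself uses in `DataSourceStrategy`:

```scala
import org.apache.spark.sql.connector.expressions.{Count, CountStar, FieldReference, Max}

// FieldReference("SALARY").describe() renders as SALARY.
val salary = FieldReference("SALARY").asInstanceOf[FieldReference]
new Max(salary).toString          // "MAX(SALARY)"
new Count(salary, true).toString  // "COUNT(DISTINCT SALARY)"
new CountStar().toString          // "COUNT(*)"
```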
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
index e4e860e3f3bd9..345194f27ac85 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.connector.expressions;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.types.DataType;
 
 /**
  * An aggregate function that returns the summation of all the values in a group.
@@ -28,22 +27,23 @@
 @Evolving
 public final class Sum implements AggregateFunc {
   private final FieldReference column;
-  private final DataType dataType;
   private final boolean isDistinct;
 
-  public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+  public Sum(FieldReference column, boolean isDistinct) {
     this.column = column;
-    this.dataType = dataType;
     this.isDistinct = isDistinct;
   }
 
   public FieldReference column() { return column; }
-  public DataType dataType() { return dataType; }
   public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() {
-    return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
+    if (isDistinct) {
+      return "SUM(DISTINCT " + column.describe() + ")";
+    } else {
+      return "SUM(" + column.describe() + ")";
+    }
   }
 
   @Override
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2f334deebc8f2..81ecb2cb278e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -714,8 +714,7 @@ object DataSourceStrategy
           case _ => None
         }
       case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
-        Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
-          sum.dataType, aggregates.isDistinct))
+        Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct))
      case _ => None
    }
  } else {
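Dropping `dataType` from the connector `Sum` works because the JDBC side no longer needs to guess the result type: the whole output schema is now derived from the database itself via `getQueryOutputSchema` (added to `JDBCRDD` below). A small sketch of the narrower constructor:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, Sum}

val salary = FieldReference("SALARY").asInstanceOf[FieldReference]
// No DataType argument anymore; the result type is resolved later from the
// database's own metadata rather than carried on the expression.
val distinctSum = new Sum(salary, true)
assert(distinctSum.toString == "SUM(DISTINCT SALARY)")
```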
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index af6c407e4c904..c575e95485cea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -54,9 +54,14 @@ object JDBCRDD extends Logging {
     val url = options.url
     val table = options.tableOrQuery
     val dialect = JdbcDialects.get(url)
+    getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
+  }
+
+  def getQueryOutputSchema(
+      query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
     val conn: Connection = JdbcUtils.createConnectionFactory(options)()
     try {
-      val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+      val statement = conn.prepareStatement(query)
       try {
         statement.setQueryTimeout(options.queryTimeout)
         val rs = statement.executeQuery()
@@ -136,30 +141,30 @@ object JDBCRDD extends Logging {
   def compileAggregates(
       aggregates: Seq[AggregateFunc],
-      dialect: JdbcDialect): Seq[String] = {
+      dialect: JdbcDialect): Option[Seq[String]] = {
     def quote(colName: String): String = dialect.quoteIdentifier(colName)
-    aggregates.map {
+    Some(aggregates.map {
       case min: Min =>
-        assert(min.column.fieldNames.length == 1)
+        if (min.column.fieldNames.length != 1) return None
         s"MIN(${quote(min.column.fieldNames.head)})"
       case max: Max =>
-        assert(max.column.fieldNames.length == 1)
+        if (max.column.fieldNames.length != 1) return None
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
-        assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDistinct) "DISTINCT" else ""
+        if (count.column.fieldNames.length != 1) return None
+        val distinct = if (count.isDistinct) "DISTINCT " else ""
         val column = quote(count.column.fieldNames.head)
-        s"COUNT($distinct $column)"
+        s"COUNT($distinct$column)"
       case sum: Sum =>
-        assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDistinct) "DISTINCT" else ""
+        if (sum.column.fieldNames.length != 1) return None
+        val distinct = if (sum.isDistinct) "DISTINCT " else ""
         val column = quote(sum.column.fieldNames.head)
-        s"SUM($distinct $column)"
+        s"SUM($distinct$column)"
       case _: CountStar =>
-        s"COUNT(1)"
-      case _ => ""
-    }
+        s"COUNT(*)"
+      case _ => return None
+    })
   }
 
   /**
@@ -185,7 +190,7 @@ object JDBCRDD extends Logging {
       parts: Array[Partition],
       options: JDBCOptions,
       outputSchema: Option[StructType] = None,
-      groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
+      groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -221,7 +226,7 @@ private[jdbc] class JDBCRDD(
     partitions: Array[Partition],
     url: String,
     options: JDBCOptions,
-    groupByColumns: Option[Array[FieldReference]])
+    groupByColumns: Option[Array[String]])
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -266,10 +271,8 @@ private[jdbc] class JDBCRDD(
    */
   private def getGroupByClause: String = {
     if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
-      assert(groupByColumns.get.forall(_.fieldNames.length == 1))
-      val dialect = JdbcDialects.get(url)
-      val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
-      s"GROUP BY ${quotedColumns.mkString(", ")}"
+      // The GROUP BY columns should already be quoted on the caller side.
+      s"GROUP BY ${groupByColumns.get.mkString(", ")}"
     } else {
       ""
     }
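`compileAggregates` now returns `Option[Seq[String]]` and bails out with `None` instead of asserting, so a single nested-column or unsupported aggregate cancels the whole push-down rather than crashing the query. A usage sketch under assumed inputs (the URL and wrapper function are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.connector.expressions.AggregateFunc
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.jdbc.JdbcDialects

def tryCompile(aggs: Seq[AggregateFunc]): Option[Seq[String]] = {
  val dialect = JdbcDialects.get("jdbc:h2:mem:testdb") // illustrative URL
  // Some(Seq("MAX(\"SALARY\")", "MIN(\"BONUS\")", ...)) on success;
  // None if any aggregate references a nested column or is unsupported.
  JDBCRDD.compileAggregates(aggs, dialect)
}
```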
+ s"GROUP BY ${groupByColumns.get.mkString(", ")}" } else { "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 5fb26d2f5e79b..60d88b6690587 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} -import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects @@ -291,9 +290,9 @@ private[sql] case class JDBCRelation( def buildScan( requiredColumns: Array[String], - requireSchema: Option[StructType], + finalSchema: StructType, filters: Array[Filter], - groupByColumns: Option[Array[FieldReference]]): RDD[Row] = { + groupByColumns: Option[Array[String]]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -302,7 +301,7 @@ private[sql] case class JDBCRelation( filters, parts, jdbcOptions, - requireSchema, + Some(finalSchema), groupByColumns).asInstanceOf[RDD[Row]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 34b64313c6c45..6eedeba4d0b5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -91,7 +91,7 @@ object PushDownUtils extends PredicateHelper { } scanBuilder match { - case r: SupportsPushDownAggregates => + case r: SupportsPushDownAggregates if aggregates.nonEmpty => val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate) val translatedGroupBys = groupBy.flatMap(columnAsString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index a1fc981a69ff9..d05519b8801c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -76,9 +78,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => sHolder.builder match { case _: SupportsPushDownAggregates => + val 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index a1fc981a69ff9..d05519b8801c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -76,9 +78,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
         sHolder.builder match {
           case _: SupportsPushDownAggregates =>
+            val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+            var ordinal = 0
             val aggregates = resultExpressions.flatMap { expr =>
               expr.collect {
-                case agg: AggregateExpression => agg
+                // Do not push down duplicated aggregate expressions. For example,
+                // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+                // `max(a)` to the data source.
+                case agg: AggregateExpression
+                    if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+                  aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+                  ordinal += 1
+                  agg
               }
             }
             val pushedAggregates = PushDownUtils
@@ -144,19 +155,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
             // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
             // scalastyle:on
-            var i = 0
             val aggOutput = output.drop(groupAttrs.length)
             plan.transformExpressions {
               case agg: AggregateExpression =>
+                val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
                 val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match {
-                  case max: aggregate.Max => max.copy(child = aggOutput(i))
-                  case min: aggregate.Min => min.copy(child = aggOutput(i))
-                  case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
-                  case _: aggregate.Count => aggregate.Sum(aggOutput(i))
+                  case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
+                  case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
+                  case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
+                  case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
                   case other => other
                 }
-                i += 1
                 agg.copy(aggregateFunction = aggFunction)
             }
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index d6ae7c893aeef..ef42691e5ca94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
 import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
@@ -29,7 +28,7 @@ case class JDBCScan(
     prunedSchema: StructType,
     pushedFilters: Array[Filter],
     pushedAggregateColumn: Array[String] = Array(),
-    groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
+    groupByColumns: Option[Array[String]]) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 
@@ -39,13 +38,12 @@ case class JDBCScan(
       override def schema: StructType = prunedSchema
       override def needConversion: Boolean = relation.needConversion
       override def buildScan(): RDD[Row] = {
-        if (groupByColumns.isEmpty) {
-          relation.buildScan(
-            prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns)
+        val columnList = if (groupByColumns.isEmpty) {
+          prunedSchema.map(_.name).toArray
         } else {
-          relation.buildScan(
-            pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns)
+          pushedAggregateColumn
         }
+        relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
       }
     }.asInstanceOf[T]
   }
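The `aggExprToOutputOrdinal` map above serves double duty: it deduplicates aggregates before push-down and later tells the rewrite which scan output column each original `AggregateExpression` should read. A toy sketch of the bookkeeping, with strings standing in for canonicalized expressions:

```scala
import scala.collection.mutable

val aggExprToOutputOrdinal = mutable.HashMap.empty[String, Int]
var ordinal = 0
// `SELECT max(a) + 1, max(a) + 2, min(b)` yields these aggregate occurrences:
val occurrences = Seq("max(a)", "max(a)", "min(b)")
val pushed = occurrences.filter { agg =>
  val firstSeen = !aggExprToOutputOrdinal.contains(agg)
  if (firstSeen) { aggExprToOutputOrdinal(agg) = ordinal; ordinal += 1 }
  firstSeen
}
assert(pushed == Seq("max(a)", "min(b)"))     // only one max(a) is pushed
assert(aggExprToOutputOrdinal("max(a)") == 0) // both usages read column 0
```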
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index afdc822c665a8..89fa6213f5731 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -16,27 +16,33 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.jdbc
 
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 case class JDBCScanBuilder(
     session: SparkSession,
     schema: StructType,
     jdbcOptions: JDBCOptions)
-  extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
-    with SupportsPushDownAggregates{
+  extends ScanBuilder
+  with SupportsPushDownFilters
+  with SupportsPushDownRequiredColumns
+  with SupportsPushDownAggregates
+  with Logging {
 
   private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
 
   private var pushedFilter = Array.empty[Filter]
 
-  private var prunedSchema = schema
+  private var finalSchema = schema
 
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     if (jdbcOptions.pushDownPredicate) {
@@ -51,56 +57,45 @@ case class JDBCScanBuilder(
 
   override def pushedFilters(): Array[Filter] = pushedFilter
 
-  private var pushedAggregations = Option.empty[Aggregation]
-
-  private var pushedAggregateColumn: Array[String] = Array()
+  private var pushedAggregateList: Array[String] = Array()
 
-  private def getStructFieldForCol(col: FieldReference): StructField =
-    schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+  private var pushedGroupByCols: Option[Array[String]] = None
 
   override def pushAggregation(aggregation: Aggregation): Boolean = {
     if (!jdbcOptions.pushDownAggregate) return false
 
     val dialect = JdbcDialects.get(jdbcOptions.url)
     val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
+    if (compiledAgg.isEmpty) return false
 
-    var outputSchema = new StructType()
-    aggregation.groupByColumns.foreach { col =>
-      val structField = getStructFieldForCol(col)
-      outputSchema = outputSchema.add(structField)
-      pushedAggregateColumn = pushedAggregateColumn :+ dialect.quoteIdentifier(structField.name)
+    val groupByCols = aggregation.groupByColumns.map { col =>
+      if (col.fieldNames.length != 1) return false
+      dialect.quoteIdentifier(col.fieldNames.head)
     }
 
     // The column names here are already quoted and can be used to build sql string directly.
     // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
     // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
     // GROUP BY "DEPT", "NAME"
-    pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg
-
-    aggregation.aggregateExpressions.foreach {
-      case max: Max =>
-        val structField = getStructFieldForCol(max.column)
-        outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")"))
-      case min: Min =>
-        val structField = getStructFieldForCol(min.column)
-        outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
-      case count: Count =>
-        val distinct = if (count.isDistinct) "DISTINCT " else ""
-        val structField = getStructFieldForCol(count.column)
-        outputSchema =
-          outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
-      case _: CountStar =>
-        outputSchema = outputSchema.add(StructField("count(*)", LongType))
-      case sum: Sum =>
-        val distinct = if (sum.isDistinct) "DISTINCT " else ""
-        val structField = getStructFieldForCol(sum.column)
-        outputSchema =
-          outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
-      case _ => return false
+    val selectList = groupByCols ++ compiledAgg.get
+    val groupByClause = if (groupByCols.isEmpty) {
+      ""
+    } else {
+      "GROUP BY " + groupByCols.mkString(",")
+    }
+
+    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
+      s"WHERE 1=0 $groupByClause"
+    try {
+      finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
+      pushedAggregateList = selectList
+      pushedGroupByCols = Some(groupByCols)
+      true
+    } catch {
+      case NonFatal(e) =>
+        logError("Failed to push down aggregation to JDBC", e)
+        false
     }
-    this.pushedAggregations = Some(aggregation)
-    prunedSchema = outputSchema
-    true
   }
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
@@ -112,7 +107,7 @@ case class JDBCScanBuilder(
       val colName = PartitioningUtils.getColName(field, isCaseSensitive)
       requiredCols.contains(colName)
     }
-    prunedSchema = StructType(fields)
+    finalSchema = StructType(fields)
   }
 
   override def build(): Scan = {
@@ -120,19 +115,14 @@ case class JDBCScanBuilder(
     val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
     val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
 
-    // in prunedSchema, the schema is either pruned in pushAggregation (if aggregates are
+    // the `finalSchema` is either pruned in pushAggregation (if aggregates are
     // pushed down), or pruned in pruneColumns (in regular column pruning). These
     // two are mutual exclusive.
     // For aggregate push down case, we want to pass down the quoted column lists such as
     // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
     // be used in sql string.
-    val groupByColumns = if (pushedAggregations.nonEmpty) {
-      Some(pushedAggregations.get.groupByColumns)
-    } else {
-      Option.empty[Array[FieldReference]]
-    }
-    JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter,
-      pushedAggregateColumn, groupByColumns)
+    JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
+      pushedAggregateList, pushedGroupByCols)
   }
 }
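The key trick in the new `pushAggregation` is the zero-row probe query: instead of hand-building a `StructType` per aggregate (the deleted `foreach`), it lets the database itself type the aggregate output. A sketch of the SQL it generates, with illustrative identifiers:

```scala
val selectList = Seq("\"DEPT\"", "MAX(\"SALARY\")", "MIN(\"BONUS\")")
val groupByClause = "GROUP BY " + Seq("\"DEPT\"").mkString(",")

// WHERE 1=0 returns no rows, so executing this costs only the round trip
// needed to read the ResultSet metadata into a StructType.
val aggQuery = s"SELECT ${selectList.mkString(",")} FROM \"test\".\"employee\" " +
  s"WHERE 1=0 $groupByClause"
// => SELECT "DEPT",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
//    WHERE 1=0 GROUP BY "DEPT"
```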
"DEPT","NAME",MAX("SALARY"),MIN("BONUS") => // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" // GROUP BY "DEPT", "NAME" - pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg - - aggregation.aggregateExpressions.foreach { - case max: Max => - val structField = getStructFieldForCol(max.column) - outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")")) - case min: Min => - val structField = getStructFieldForCol(min.column) - outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")")) - case count: Count => - val distinct = if (count.isDistinct) "DISTINCT " else "" - val structField = getStructFieldForCol(count.column) - outputSchema = - outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType)) - case _: CountStar => - outputSchema = outputSchema.add(StructField("count(*)", LongType)) - case sum: Sum => - val distinct = if (sum.isDistinct) "DISTINCT " else "" - val structField = getStructFieldForCol(sum.column) - outputSchema = - outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType)) - case _ => return false + val selectList = groupByCols ++ compiledAgg.get + val groupByClause = if (groupByCols.isEmpty) { + "" + } else { + "GROUP BY " + groupByCols.mkString(",") + } + + val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " + + s"WHERE 1=0 $groupByClause" + try { + finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect) + pushedAggregateList = selectList + pushedGroupByCols = Some(groupByCols) + true + } catch { + case NonFatal(e) => + logError("Failed to push down aggregation to JDBC", e) + false } - this.pushedAggregations = Some(aggregation) - prunedSchema = outputSchema - true } override def pruneColumns(requiredSchema: StructType): Unit = { @@ -112,7 +107,7 @@ case class JDBCScanBuilder( val colName = PartitioningUtils.getColName(field, isCaseSensitive) requiredCols.contains(colName) } - prunedSchema = StructType(fields) + finalSchema = StructType(fields) } override def build(): Scan = { @@ -120,19 +115,14 @@ case class JDBCScanBuilder( val timeZoneId = session.sessionState.conf.sessionLocalTimeZone val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) - // in prunedSchema, the schema is either pruned in pushAggregation (if aggregates are + // the `finalSchema` is either pruned in pushAggregation (if aggregates are // pushed down), or pruned in pruneColumns (in regular column pruning). These // two are mutual exclusive. // For aggregate push down case, we want to pass down the quoted column lists such as // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. 
@@ -322,7 +322,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]"
+          "PushedAggregates: [SUM(SALARY)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(53000)))
@@ -333,7 +333,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]"
+          "PushedAggregates: [SUM(DISTINCT SALARY)]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
     checkAnswer(df, Seq(Row(31000)))
@@ -344,7 +344,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+          "PushedAggregates: [SUM(SALARY)], " +
           "PushedFilters: [], " +
           "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -357,7 +357,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " +
+          "PushedAggregates: [SUM(DISTINCT SALARY)], " +
           "PushedFilters: [], " +
           "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -375,7 +375,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+          "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
           "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
           "PushedGroupby: [DEPT, NAME]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -394,7 +394,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+          "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
           "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
           "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -409,7 +409,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Min(SALARY)], " +
+          "PushedAggregates: [MIN(SALARY)], " +
           "PushedFilters: [], " +
           "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
@@ -432,7 +432,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+          "PushedAggregates: [SUM(SALARY)], " +
           "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
           "PushedGroupby: [DEPT]"
         checkKeywordsExistsInExplain(query, expected_plan_fragment)
@@ -447,7 +447,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     query.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
        val expected_plan_fragment =
-          "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false), Sum(BONUS,DoubleType,false)"
+          "PushedAggregates: [SUM(SALARY), SUM(BONUS)"
        checkKeywordsExistsInExplain(query, expected_plan_fragment)
     }
     checkAnswer(query, Seq(Row(47100.0)))
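For reference, a sketch of how these tests exercise the path end to end: JDBCV2Suite registers an H2 catalog so queries go through `JDBCScanBuilder`, and the explain output is then matched against the new `PushedAggregates` strings. Catalog name, URL, and the ambient `spark` session are illustrative assumptions here:

```scala
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog

// Assumed setup, mirroring the suite's sparkConf entries.
spark.conf.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
spark.conf.set("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb")
spark.conf.set("spark.sql.catalog.h2.pushDownAggregate", "true")

val df = spark.sql(
  "SELECT MAX(SALARY), MIN(BONUS) FROM h2.test.employee WHERE DEPT > 0 GROUP BY DEPT")
// Expected in the optimized plan's scan:
//   PushedAggregates: [MAX(SALARY), MIN(BONUS)],
//   PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], PushedGroupby: [DEPT]
df.explain(true)
```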