diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java new file mode 100644 index 0000000000000..eea8c3152e602 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +import java.io.Serializable; + +/** + * Base class of the Aggregate Functions. + * + * @since 3.2.0 + */ +@Evolving +public interface AggregateFunc extends Expression, Serializable { +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java new file mode 100644 index 0000000000000..fdf30312f156f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +import java.io.Serializable; + +/** + * Aggregation in SQL statement. 
+ * + * @since 3.2.0 + */ +@Evolving +public final class Aggregation implements Serializable { + private AggregateFunc[] aggregateExpressions; + private FieldReference[] groupByColumns; + + public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) { + this.aggregateExpressions = aggregateExpressions; + this.groupByColumns = groupByColumns; + } + + public AggregateFunc[] aggregateExpressions() { + return aggregateExpressions; + } + + public FieldReference[] groupByColumns() { + return groupByColumns; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java new file mode 100644 index 0000000000000..17562a1aa1763 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +/** + * An aggregate function that returns the number of the specific row in a group. + * + * @since 3.2.0 + */ +@Evolving +public final class Count implements AggregateFunc { + private FieldReference column; + private boolean isDistinct; + + public Count(FieldReference column, boolean isDistinct) { + this.column = column; + this.isDistinct = isDistinct; + } + + public FieldReference column() { + return column; + } + public boolean isDinstinct() { + return isDistinct; + } + + @Override + public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java new file mode 100644 index 0000000000000..777a99d58e4d8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +/** + * An aggregate function that returns the number of rows in a group. + * + * @since 3.2.0 + */ +@Evolving +public final class CountStar implements AggregateFunc { + + public CountStar() { + } + + @Override + public String toString() { + return "CountStar()"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java new file mode 100644 index 0000000000000..fe7689c18ac66 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +/** + * An aggregate function that returns the maximum value in a group. + * + * @since 3.2.0 + */ +@Evolving +public final class Max implements AggregateFunc { + private FieldReference column; + + public Max(FieldReference column) { + this.column = column; + } + + public FieldReference column() { return column; } + + @Override + public String toString() { + return "Max(" + column.describe() + ")"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java new file mode 100644 index 0000000000000..f528b0bedfd67 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +/** + * An aggregate function that returns the minimum value in a group. + * + * @since 3.2.0 + */ +@Evolving +public final class Min implements AggregateFunc { + private FieldReference column; + + public Min(FieldReference column) { + this.column = column; + } + + public FieldReference column() { + return column; + } + + @Override + public String toString() { + return "Min(" + column.describe() + ")"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java new file mode 100644 index 0000000000000..4cb34bee28d9b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.DataType; + +/** + * An aggregate function that returns the summation of all the values in a group. + * + * @since 3.2.0 + */ +@Evolving +public final class Sum implements AggregateFunc { + private FieldReference column; + private DataType dataType; + private boolean isDistinct; + + public Sum(FieldReference column, DataType dataType, boolean isDistinct) { + this.column = column; + this.dataType = dataType; + this.isDistinct = isDistinct; + } + + public FieldReference column() { + return column; + } + public DataType dataType() { + return dataType; + } + public boolean isDinstinct() { + return isDistinct; + } + + @Override + public String toString() { + return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index cb3eea7680058..b46f620d4fedb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -22,7 +22,8 @@ /** * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ * interfaces to do operator pushdown, and keep the operator pushdown result in the returned - * {@link Scan}. + * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down + * aggregates or applies column pruning. 
* * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java new file mode 100644 index 0000000000000..7efa333bdaa2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Aggregation; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down aggregates. Spark assumes that the data source can't fully complete the + * grouping work, and will group the data source output again. For queries like + * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate + * to the data source, the data source can still output data with duplicated keys, which is OK + * as Spark will do GROUP BY key again. The final query plan can be something like this: + * {{{ + * Aggregate [key#1], [min(min(value)#2) AS m#3] + * +- RelationV2[key#1, min(value)#2] + * }}} + * + *
+ * Similarly, if there is no grouping expression, the data source can still output more than one + * row. + + *
+ * When pushing down operators, Spark pushes down filters to the data source first, then pushes down + * aggregates or applies column pruning. Depending on the data source implementation, aggregates may or + * may not be pushed down together with filters. If pushed filters still need to be evaluated + * after scanning, aggregates can't be pushed down. + * + * @since 3.2.0 + */ +@Evolving +public interface SupportsPushDownAggregates extends ScanBuilder { + + /** + * Pushes down an Aggregation to the data source. The order of the data source scan output columns should + * be: grouping columns, aggregate columns (in the same order as the aggregate functions in + * the given Aggregation). + */ + boolean pushAggregation(Aggregation aggregation); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index de991fc1fd1bb..603d53a295609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -102,6 +103,7 @@ case class RowDataSourceScanExec( requiredSchema: StructType, filters: Set[Filter], handledFilters: Set[Filter], + aggregation: Option[Aggregation], rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -129,12 +131,29 @@ case class RowDataSourceScanExec( override def inputRDD: RDD[InternalRow] = rdd override val metadata: Map[String, String] = { - val markedFilters = for (filter <- filters) yield { - if (handledFilters.contains(filter)) s"*$filter" else s"$filter" + + def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") + + val (aggString, groupByString) = if (aggregation.nonEmpty) { + (seqToString(aggregation.get.aggregateExpressions), + seqToString(aggregation.get.groupByColumns)) + } else { + ("[]", "[]") + } + + val markedFilters = if (filters.nonEmpty) { + for (filter <- filters) yield { + if (handledFilters.contains(filter)) s"*$filter" else s"$filter" + } + } else { + handledFilters } + Map( "ReadSchema" -> requiredSchema.catalogString, - "PushedFilters" -> markedFilters.mkString("[", ", ", "]")) + "PushedFilters" -> seqToString(markedFilters.toSeq), + "PushedAggregates" -> aggString, + "PushedGroupby" -> groupByString) } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. 
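To make the new contract concrete, here is a minimal sketch (not part of this patch) of how a connector's ScanBuilder could opt into SupportsPushDownAggregates. The class name MyScanBuilder and the stubbed build() are hypothetical; the sketch relies only on the interface added above and on the rule that the scan output lists the grouping columns first, followed by one column per pushed aggregate function.

import org.apache.spark.sql.connector.expressions.{Aggregation, CountStar, Max, Min}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}

// Hypothetical connector code, shown only to illustrate the new mix-in.
class MyScanBuilder extends SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Accept the pushdown only if every aggregate function is one this source can
    // compute; returning false leaves the whole Aggregate to be evaluated by Spark.
    val supported = aggregation.aggregateExpressions.forall {
      case _: Min | _: Max | _: CountStar => true
      case _ => false
    }
    if (supported) pushedAggregation = Some(aggregation)
    supported
  }

  override def build(): Scan = {
    // The returned Scan must emit the grouping columns first, then one column per
    // aggregate function, in the order given by the pushed Aggregation. Because Spark
    // re-aggregates the scan output, returning partially aggregated groups is fine.
    throw new UnsupportedOperationException("sketch only")
  }
}

Since pushAggregation can simply return false, existing sources keep their current behaviour until they are ready to take advantage of the pushdown.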
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9e33723e5fabd..2f334deebc8f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -33,12 +33,14 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -332,6 +334,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, + None, toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -405,6 +408,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, + None, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -427,6 +431,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, + None, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -692,6 +697,32 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } + protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { + if (aggregates.filter.isEmpty) { + aggregates.aggregateFunction match { + case aggregate.Min(PushableColumnWithoutNestedColumn(name)) => + Some(new Min(FieldReference(name).asInstanceOf[FieldReference])) + case aggregate.Max(PushableColumnWithoutNestedColumn(name)) => + Some(new Max(FieldReference(name).asInstanceOf[FieldReference])) + case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table + case Literal(_, _) => Some(new CountStar()) + case PushableColumnWithoutNestedColumn(name) => + Some(new Count(FieldReference(name).asInstanceOf[FieldReference], + aggregates.isDistinct)) + case _ => None + } + case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => + Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], + sum.dataType, aggregates.isDistinct)) + case _ => None + } + } else { + None + } + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 97d4f2d97622e..8b2ae2beb6d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -188,6 +188,9 @@ class JDBCOptions( // An option to allow/disallow pushing down predicate into JDBC data source val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean + // An option to allow/disallow pushing down aggregate into JDBC data source + val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean + // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either // by --files option of spark-submit or manually val keytab = { @@ -259,6 +262,7 @@ object JDBCOptions { val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") + val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") val JDBC_TABLE_COMMENT = newOption("tableComment") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 87ca78db59b29..c22ca1502b512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -133,6 +134,34 @@ object JDBCRDD extends Logging { }) } + def compileAggregates( + aggregates: Seq[AggregateFunc], + dialect: JdbcDialect): Seq[String] = { + def quote(colName: String): String = dialect.quoteIdentifier(colName) + + aggregates.map { + case min: Min => + assert(min.column.fieldNames.length == 1) + s"MIN(${quote(min.column.fieldNames.head)})" + case max: Max => + assert(max.column.fieldNames.length == 1) + s"MAX(${quote(max.column.fieldNames.head)})" + case count: Count => + assert(count.column.fieldNames.length == 1) + val distinct = if (count.isDinstinct) "DISTINCT" else "" + val column = quote(count.column.fieldNames.head) + s"COUNT($distinct $column)" + case sum: Sum => + assert(sum.column.fieldNames.length == 1) + val distinct = if (sum.isDinstinct) "DISTINCT" else "" + val column = quote(sum.column.fieldNames.head) + s"SUM($distinct $column)" + case _: CountStar => + s"COUNT(1)" + case _ => "" + } + } + /** * Build and return JDBCRDD from the given information. * @@ -143,6 +172,8 @@ object JDBCRDD extends Logging { * @param parts - An array of JDBCPartitions specifying partition ids and * per-partition WHERE clauses. * @param options - JDBC options that contains url, table and other information. + * @param requiredSchema - The schema of the columns to SELECT. 
+ * @param aggregation - The pushed down aggregation * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ @@ -152,19 +183,27 @@ object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition], - options: JDBCOptions): RDD[InternalRow] = { + options: JDBCOptions, + outputSchema: Option[StructType] = None, + groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) - val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) + val quotedColumns = if (groupByColumns.isEmpty) { + requiredColumns.map(colName => dialect.quoteIdentifier(colName)) + } else { + // these are already quoted in JDBCScanBuilder + requiredColumns + } new JDBCRDD( sc, JdbcUtils.createConnectionFactory(options), - pruneSchema(schema, requiredColumns), + outputSchema.getOrElse(pruneSchema(schema, requiredColumns)), quotedColumns, filters, parts, url, - options) + options, + groupByColumns) } } @@ -181,7 +220,8 @@ private[jdbc] class JDBCRDD( filters: Array[Filter], partitions: Array[Partition], url: String, - options: JDBCOptions) + options: JDBCOptions, + groupByColumns: Option[Array[FieldReference]]) extends RDD[InternalRow](sc, Nil) { /** @@ -221,6 +261,20 @@ private[jdbc] class JDBCRDD( } } + /** + * A GROUP BY clause representing pushed-down grouping columns. + */ + private def getGroupByClause: String = { + if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) { + assert(groupByColumns.get.forall(_.fieldNames.length == 1)) + val dialect = JdbcDialects.get(url) + val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head)) + s"GROUP BY ${quotedColumns.mkString(", ")}" + } else { + "" + } + } + /** * Runs the SQL query against the JDBC driver. 
* @@ -296,7 +350,8 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + + s" $getGroupByClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 4ec9a4f98c6d6..5fb26d2f5e79b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects @@ -288,6 +289,23 @@ private[sql] case class JDBCRelation( jdbcOptions).asInstanceOf[RDD[Row]] } + def buildScan( + requiredColumns: Array[String], + requireSchema: Option[StructType], + filters: Array[Filter], + groupByColumns: Option[Array[FieldReference]]): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] + JDBCRDD.scanTable( + sparkSession.sparkContext, + schema, + requiredColumns, + filters, + parts, + jdbcOptions, + requireSchema, + groupByColumns).asInstanceOf[RDD[Row]] + } + override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 7be13791ce2f0..70b81d8e99ea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -87,7 +87,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) => + DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -98,8 +98,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val dsScan = RowDataSourceScanExec( output, output.toStructType, - translated.toSet, + Set.empty, pushed.toSet, + aggregate, unsafeRowRDD, v1Relation, tableIdentifier = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 1f57f17911457..ab5c5da43a79f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,9 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -70,6 +74,42 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down aggregates to the data source reader + * + * @return pushed aggregation. + */ + def pushAggregates( + scanBuilder: ScanBuilder, + aggregates: Seq[AggregateExpression], + groupBy: Seq[Expression]): Option[Aggregation] = { + + def columnAsString(e: Expression): Option[FieldReference] = e match { + case PushableColumnWithoutNestedColumn(name) => + Some(FieldReference(name).asInstanceOf[FieldReference]) + case _ => None + } + + scanBuilder match { + case r: SupportsPushDownAggregates => + val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten + val translatedGroupBys = groupBy.map(columnAsString).flatten + + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray) + if (r.pushAggregation(agg)) { + Some(agg) + } else { + None + } + case _ => None + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d2180566790ac..445ff033d4987 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,23 +17,36 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.{Scan, V1Scan} +import org.apache.spark.sql.connector.expressions.Aggregation +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType -object V2ScanRelationPushDown extends Rule[LogicalPlan] { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + def apply(plan: LogicalPlan): LogicalPlan = { + applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan)))) + } + + private def createScanBuilder(plan: LogicalPlan) = plan.transform { + case r: DataSourceV2Relation => + ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options)) + } - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) + private def pushDownFilters(plan: LogicalPlan) = plan.transform { + // update the scan builder with filter push down and return a new plan with filter pushed + case Filter(condition, sHolder: ScanBuilderHolder) => + val filters = splitConjunctivePredicates(condition) + val normalizedFilters = + DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output) val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = normalizedFilters.partition(SubqueryExpression.hasSubquery) @@ -41,37 +54,142 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) + sHolder.builder, normalizedFiltersWithoutSubquery) val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + """.stripMargin) + + val filterCondition = postScanFilters.reduceLeftOption(And) + filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) + } + + def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { + // update the scan builder with agg pushdown and return a new plan with agg pushed + case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => + child match { + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) + if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => + sHolder.builder match { + case _: SupportsPushDownAggregates => + val aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + } + val pushedAggregates = PushDownUtils + .pushAggregates(sHolder.builder, aggregates, groupingExpressions) + if (pushedAggregates.isEmpty) { + aggNode // return original plan node + } else { + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. + val scan = sHolder.builder.build() + + // scalastyle:off + // use the group by columns and aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] + // scalastyle:on + val newOutput = scan.readSchema().toAttributes + assert(newOutput.length == groupingExpressions.length + aggregates.length) + val groupAttrs = groupingExpressions.zip(newOutput).map { + case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) + case (_, b) => b + } + val output = groupAttrs ++ newOutput.drop(groupAttrs.length) + + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Aggregate Functions: + | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} + |Pushed Group by: + | ${pushedAggregates.get.groupByColumns.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) + + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) + + val plan = Aggregate( + output.take(groupingExpressions.length), resultExpressions, scanRelation) + + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
+ // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // scalastyle:on + var i = 0 + val aggOutput = output.drop(groupAttrs.length) + plan.transformExpressions { + case agg: AggregateExpression => + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => max.copy(child = aggOutput(i)) + case min: aggregate.Min => min.copy(child = aggOutput(i)) + case sum: aggregate.Sum => sum.copy(child = aggOutput(i)) + case _: aggregate.Count => aggregate.Sum(aggOutput(i)) + case other => other + } + i += 1 + agg.copy(aggregateFunction = aggFunction) + } + } + case _ => aggNode + } + case _ => aggNode + } + } + + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + // column pruning val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) + .normalizeExprs(project, sHolder.output) .asInstanceOf[Seq[NamedExpression]] val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) + sHolder.builder, sHolder.relation, normalizedProjects, filters) + logInfo( s""" - |Pushing operators to ${relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} - |Post-Scan Filters: ${postScanFilters.mkString(",")} |Output: ${output.mkString(", ")} """.stripMargin) - val wrappedScan = scan match { - case v1: V1Scan => - val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) - V1ScanWrapper(v1, translated, pushedFilters) - case _ => scan - } + val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation]) - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) val projectionOverSchema = ProjectionOverSchema(output.toStructType) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } - val filterCondition = postScanFilters.reduceLeftOption(And) + val filterCondition = filters.reduceLeftOption(And) val newFilterCondition = filterCondition.map(projectionFunc) val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) @@ -83,16 +201,36 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { } else { withFilter } - withProjection } + + private def getWrappedScan( + scan: Scan, + sHolder: ScanBuilderHolder, + aggregation: Option[Aggregation]): Scan = { + scan match { + case v1: V1Scan => + val pushedFilters = sHolder.builder match { + case f: SupportsPushDownFilters => + f.pushedFilters() + case _ => Array.empty[sources.Filter] + } + V1ScanWrapper(v1, pushedFilters, aggregation) + case _ => scan + } + } } +case class ScanBuilderHolder( + output: Seq[AttributeReference], + relation: DataSourceV2Relation, + builder: ScanBuilder) extends LeafNode + // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by // the physical v1 scan node. 
case class V1ScanWrapper( v1Scan: V1Scan, - translatedFilters: Seq[sources.Filter], - handledFilters: Seq[sources.Filter]) extends Scan { + handledFilters: Seq[sources.Filter], + pushedAggregate: Option[Aggregation]) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 860232ba84f39..d6ae7c893aeef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} @@ -26,7 +27,9 @@ import org.apache.spark.sql.types.StructType case class JDBCScan( relation: JDBCRelation, prunedSchema: StructType, - pushedFilters: Array[Filter]) extends V1Scan { + pushedFilters: Array[Filter], + pushedAggregateColumn: Array[String] = Array(), + groupByColumns: Option[Array[FieldReference]]) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -36,14 +39,28 @@ case class JDBCScan( override def schema: StructType = prunedSchema override def needConversion: Boolean = relation.needConversion override def buildScan(): RDD[Row] = { - relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters) + if (groupByColumns.isEmpty) { + relation.buildScan( + prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns) + } else { + relation.buildScan( + pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns) + } } }.asInstanceOf[T] } override def description(): String = { + val (aggString, groupByString) = if (groupByColumns.nonEmpty) { + val groupByColumnsLength = groupByColumns.get.length + (seqToString(pushedAggregateColumn.drop(groupByColumnsLength)), + seqToString(pushedAggregateColumn.take(groupByColumnsLength))) + } else { + ("[]", "[]") + } super.description() + ", prunedSchema: " + seqToString(prunedSchema) + - ", PushedFilters: " + seqToString(pushedFilters) + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString } private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 270c5b6d92e32..7442edaafd67a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, FieldReference, Max, Min, Sum} +import 
org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{LongType, StructField, StructType} case class JDBCScanBuilder( session: SparkSession, schema: StructType, jdbcOptions: JDBCOptions) - extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns { + extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns + with SupportsPushDownAggregates{ private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -49,6 +51,58 @@ case class JDBCScanBuilder( override def pushedFilters(): Array[Filter] = pushedFilter + private var pushedAggregations = Option.empty[Aggregation] + + private var pushedAggregateColumn: Array[String] = Array() + + private def getStructFieldForCol(col: FieldReference): StructField = + schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head)) + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!jdbcOptions.pushDownAggregate) return false + + val dialect = JdbcDialects.get(jdbcOptions.url) + val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect) + + var outputSchema = new StructType() + aggregation.groupByColumns.foreach { col => + val structField = getStructFieldForCol(col) + outputSchema = outputSchema.add(structField) + pushedAggregateColumn = pushedAggregateColumn :+ dialect.quoteIdentifier(structField.name) + } + + // The column names here are already quoted and can be used to build sql string directly. + // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") => + // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" + // GROUP BY "DEPT", "NAME" + pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg + + aggregation.aggregateExpressions.foreach { + case max: Max => + val structField = getStructFieldForCol(max.column) + outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")")) + case min: Min => + val structField = getStructFieldForCol(min.column) + outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")")) + case count: Count => + val distinct = if (count.isDinstinct) "DISTINCT " else "" + val structField = getStructFieldForCol(count.column) + outputSchema = + outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType)) + case _: CountStar => + outputSchema = outputSchema.add(StructField("count(*)", LongType)) + case sum: Sum => + val distinct = if (sum.isDinstinct) "DISTINCT " else "" + val structField = getStructFieldForCol(sum.column) + outputSchema = + outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType)) + case _ => return false + } + this.pushedAggregations = Some(aggregation) + prunedSchema = outputSchema + true + } + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. 
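For a sense of the SQL this builder ultimately issues, the sketch below exercises the new JDBCRDD.compileAggregates helper directly. These are internal APIs, so the snippet is meant as an in-package illustration (for example a test next to JDBCRDD) rather than user code, and the H2 URL is a placeholder.

import org.apache.spark.sql.connector.expressions.{FieldReference, Max, Min}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
// Compile MAX(SALARY) and MIN(BONUS) into dialect-quoted SQL fragments.
val compiled = JDBCRDD.compileAggregates(
  Seq(new Max(FieldReference("SALARY").asInstanceOf[FieldReference]),
      new Min(FieldReference("BONUS").asInstanceOf[FieldReference])),
  dialect)
// compiled is Seq(MAX("SALARY"), MIN("BONUS")). Prefixed with the quoted grouping
// columns, the query sent to the database looks roughly like:
//   SELECT "DEPT",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" GROUP BY "DEPT"

Spark then runs its own final aggregation over this already-grouped result, as described for SupportsPushDownAggregates above.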
@@ -65,6 +119,20 @@ case class JDBCScanBuilder( val resolver = session.sessionState.conf.resolver val timeZoneId = session.sessionState.conf.sessionLocalTimeZone val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) - JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter) + + // in prunedSchema, the schema is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in pruneColumns (in regular column pruning). These + // two are mutual exclusive. + // For aggregate push down case, we want to pass down the quoted column lists such as + // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from + // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't + // be used in sql string. + val groupByColumns = if (pushedAggregations.nonEmpty) { + Some(pushedAggregations.get.groupByColumns) + } else { + Option.empty[Array[FieldReference]] + } + JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter, + pushedAggregateColumn, groupByColumns) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a3a3f47280952..c1f8f5f00e5f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -21,16 +21,16 @@ import java.sql.{Connection, DriverManager} import java.util.Properties import org.apache.spark.SparkConf -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{lit, sum, udf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils -class JDBCV2Suite extends QueryTest with SharedSparkSession { +class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper { import testImplicits._ val tempDir = Utils.createTempDir() @@ -41,6 +41,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.h2.url", url) .set("spark.sql.catalog.h2.driver", "org.h2.Driver") + .set("spark.sql.catalog.h2.pushDownAggregate", "true") private def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -64,6 +65,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { .executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + + " bonus DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") + 
.executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() } } @@ -84,6 +98,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { case f: Filter => f } assert(filters.isEmpty) + + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Row("mary", 2)) } @@ -145,7 +167,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), - Seq(Row("test", "people", false), Row("test", "empty_table", false))) + Seq(Row("test", "people", false), Row("test", "empty_table", false), + Row("test", "employee", false))) } test("SQL API: create table as select") { @@ -214,4 +237,232 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Row("bob", 4)) } } + + test("scan with aggregate push-down: MAX MIN with filter and group by") { + val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + " group by DEPT") + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.isEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Max(SALARY), Min(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupby: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) + } + + test("scan with aggregate push-down: MAX MIN with filter without group by") { + val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0") + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.isEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Max(ID), Min(ID)], " + + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + + "PushedGroupby: []" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(2, 1))) + } + + test("scan with aggregate push-down: aggregate + number") { + val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Max(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(12001))) + } + + test("scan with aggregate push-down: COUNT(*)") { + val df = sql("select COUNT(*) FROM h2.test.employee") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [CountStar()]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(5))) + } + + test("scan with aggregate push-down: COUNT(col)") { + val df = sql("select COUNT(DEPT) FROM h2.test.employee") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Count(DEPT,false)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, 
Seq(Row(5))) + } + + test("scan with aggregate push-down: COUNT(DISTINCT col)") { + val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Count(DEPT,true)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(3))) + } + + test("scan with aggregate push-down: SUM without filer and group by") { + val df = sql("SELECT SUM(SALARY) FROM h2.test.employee") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000))) + } + + test("scan with aggregate push-down: DISTINCT SUM without filer and group by") { + val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(31000))) + } + + test("scan with aggregate push-down: SUM with group by") { + val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " + + "PushedFilters: [], " + + "PushedGroupby: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) + } + + test("scan with aggregate push-down: DISTINCT SUM with group by") { + val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " + + "PushedFilters: [], " + + "PushedGroupby: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) + } + + test("scan with aggregate push-down: with multiple group by columns") { + val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + " group by DEPT, NAME") + val filters11 = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters11.isEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Max(SALARY), Min(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupby: [DEPT, NAME]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), + Row(10000, 1000), Row(12000, 1200))) + } + + test("scan with aggregate push-down: with having clause") { + val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + " group by DEPT having MIN(BONUS) > 1000") + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f // filter over aggregate not push down + } + assert(filters.nonEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Max(SALARY), Min(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + 
+ "PushedGroupby: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) + } + + test("scan with aggregate push-down: alias over aggregate") { + val df = sql("select * from h2.test.employee") + .groupBy($"DEPT") + .min("SALARY").as("total") + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Min(SALARY)], " + + "PushedFilters: [], " + + "PushedGroupby: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) + } + + test("scan with aggregate push-down: order by alias over aggregate") { + val df = spark.table("h2.test.employee") + val query = df.select($"DEPT", $"SALARY") + .filter($"DEPT" > 0) + .groupBy($"DEPT") + .agg(sum($"SALARY").as("total")) + .filter($"total" > 1000) + .orderBy($"total") + val filters = query.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.nonEmpty) // filter over aggregate not pushed down + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupby: [DEPT]" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) + } + + test("scan with aggregate push-down: udf over aggregate") { + val df = spark.table("h2.test.employee") + val decrease = udf { (x: Double, y: Double) => x - y } + val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false), Sum(BONUS,DoubleType,false)" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(47100.0))) + } + + test("scan with aggregate push-down: aggregate over alias") { + val cols = Seq("a", "b", "c", "d") + val df1 = sql("select * from h2.test.employee").toDF(cols: _*) + val df2 = df1.groupBy().sum("c") + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: []" // aggregate over alias not push down + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + checkAnswer(df2, Seq(Row(53000.00))) + } }