From 8d0256de7c7fc4685c08caf5f1ca2e53fbeff700 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 25 Mar 2021 18:20:01 -0700 Subject: [PATCH 01/30] init work for parquet aggregate push down --- .../read/SupportsPushDownAggregates.java | 44 ++++ .../expressions/aggregate/Count.scala | 22 +- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../apache/spark/sql/sources/aggregates.scala | 38 +++ .../sql/execution/DataSourceScanExec.scala | 13 +- .../datasources/DataSourceStrategy.scala | 51 ++++ .../datasources/parquet/ParquetUtils.scala | 172 ++++++++++++++ .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../datasources/v2/PushDownUtils.scala | 41 +++- .../v2/V2ScanRelationPushDown.scala | 217 ++++++++++++++---- .../ParquetPartitionReaderFactory.scala | 118 ++++++++-- .../datasources/v2/parquet/ParquetScan.scala | 4 +- .../v2/parquet/ParquetScanBuilder.scala | 33 ++- .../datasources/v2/parquet/ParquetTable.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 48 ++++ 15 files changed, 740 insertions(+), 75 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java new file mode 100644 index 0000000000000..40ed146114ffe --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.Aggregation; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down aggregates to the data source. + * + * @since 3.2.0 + */ +@Evolving +public interface SupportsPushDownAggregates extends ScanBuilder { + + /** + * Pushes down Aggregation to datasource. + * The Aggregation can be pushed down only if all the Aggregate Functions can + * be pushed down. + */ + void pushAggregation(Aggregation aggregation); + + /** + * Returns the aggregation that are pushed to the data source via + * {@link #pushAggregation(Aggregation aggregation)}. 
+ */ + Aggregation pushedAggregation(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index dfdd828d10d03..472436f3884f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -80,15 +80,25 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def defaultResult: Option[Literal] = Option(Literal(0L)) + private[sql] var pushDown: Boolean = false + override lazy val updateExpressions = { - val nullableChildren = children.filter(_.nullable) - if (nullableChildren.isEmpty) { - Seq( - /* count = */ count + 1L - ) + if (!pushDown) { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } } else { Seq( - /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + // if count is pushed down to Data Source layer, add the count result retrieved from + // Data Source + /* count = */ count + children.head ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04e740039f005..d979cd1ef966c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -772,6 +772,12 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") .createWithDefault(10) + val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") + .doc("Enables Parquet aggregate push-down optimization when set to true.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " + "values will be written in Apache Parquet's fixed-length byte array format, which other " + @@ -3423,6 +3429,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownInFilterThreshold: Int = getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala new file mode 100644 index 0000000000000..3bc263ff49f1f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.types.DataType + +// groupBy only used by JDBC agg pushdown, not supported by parquet agg pushdown yet +case class Aggregation(aggregateExpressions: Seq[AggregateFunc], + groupByExpressions: Seq[String]) + +abstract class AggregateFunc + +// Avg and Sum are only supported by JDBC agg pushdown, not supported by parquet agg pushdown yet +case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc +case class Min(column: String, dataType: DataType) extends AggregateFunc +case class Max(column: String, dataType: DataType) extends AggregateFunc +case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc +case class Count(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc + +object Aggregation { + // Returns an empty Aggregate + def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 6fa4167384925..29f9bb8be0900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{BaseRelation, Filter} +import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -102,6 +102,7 @@ case class RowDataSourceScanExec( requiredSchema: StructType, filters: Set[Filter], handledFilters: Set[Filter], + aggregation: Aggregation, rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -132,9 +133,17 @@ case class RowDataSourceScanExec( val markedFilters = for (filter <- filters) yield { if (handledFilters.contains(filter)) s"*$filter" else s"$filter" } + val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield { + s"*$aggregate" + } + val markedGroupby = for (groupby <- aggregation.groupByExpressions) yield { + s"*$groupby" + } Map( "ReadSchema" -> requiredSchema.catalogString, - "PushedFilters" -> markedFilters.mkString("[", ", ", "]")) + "PushedFilters" -> markedFilters.mkString("[", ", ", "]"), + "PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"), + "PushedGroupby" -> markedGroupby.mkString("[", ", ", "]")) } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. 
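The `Aggregation` container introduced above and the new `PushedAggregates`/`PushedGroupby` metadata entries fit together as in the following sketch; the column names and literal values here are illustrative only, not taken from the patch:

    import org.apache.spark.sql.sources.{Aggregation, Count, Min}
    import org.apache.spark.sql.types.{IntegerType, LongType}

    // Two pushed aggregate functions and no group-by columns (the Parquet
    // ScanBuilder later in this patch rejects grouping expressions).
    val pushed = Aggregation(
      Seq(Min("c1", IntegerType), Count("1", LongType, isDistinct = false)),
      Seq.empty[String])

    // Rendered the same way RowDataSourceScanExec's metadata above renders it:
    val markedAggregates = pushed.aggregateExpressions.map(a => s"*$a")
    markedAggregates.mkString("[", ", ", "]")
    // => [*Min(c1,IntegerType), *Count(1,LongType,false)]
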
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 58ac924a1d36c..b8b6c29ae77e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -332,6 +333,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, + Aggregation.empty, toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -405,6 +407,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, + Aggregation.empty, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -427,6 +430,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, + Aggregation.empty, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -673,6 +677,53 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } + private def columnAsString(e: Expression): String = e match { + case AttributeReference(name, _, _, _) => name + case Cast(child, _, _) => columnAsString (child) + // Add, Subtract, Multiply and Divide are only supported by JDBC agg pushdown + case Add(left, right, _) => + columnAsString(left) + " + " + columnAsString(right) + case Subtract(left, right, _) => + columnAsString(left) + " - " + columnAsString(right) + case Multiply(left, right, _) => + columnAsString(left) + " * " + columnAsString(right) + case Divide(left, right, _) => + columnAsString(left) + " / " + columnAsString(right) + + case CheckOverflow(child, _, _) => columnAsString (child) + case PromotePrecision(child) => columnAsString (child) + case _ => "" + } + + protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { + + aggregates.aggregateFunction match { + case min: aggregate.Min => + val colName = columnAsString(min.child) + if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None + case max: aggregate.Max => + val colName = columnAsString(max.child) + if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None + case avg: aggregate.Average => + val colName = columnAsString(avg.child) + if (colName.nonEmpty) Some(Avg(colName, avg.dataType, aggregates.isDistinct)) else None + case sum: aggregate.Sum => + val colName = columnAsString(sum.child) + if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None + case count: aggregate.Count => + val columnName = count.children.head match { + case Literal(_, _) => + "1" + case _ => columnAsString(count.children.head) + } + if (columnName.nonEmpty) { + Some(Count(columnName, count.dataType, aggregates.isDistinct)) + } + else None + case _ 
=> None + } + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b91d75c55c513..9afa2974a49b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,11 +16,26 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.util + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuilder +import scala.language.existentials + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.schema.PrimitiveType import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.sources.{Aggregation, Count, Max, Min} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { def inferSchema( @@ -127,4 +142,161 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + + private[sql] def aggResultToSparkInternalRows( + parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], + values: Seq[Any], + dataSchema: StructType): InternalRow = { + val mutableRow = new SpecificInternalRow(dataSchema.fields.map(x => x.dataType)) + + parquetTypes.zipWithIndex.map { + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + mutableRow.setInt(i, values(i).asInstanceOf[Int]) + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case (PrimitiveType.PrimitiveTypeName.INT96, i) => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + mutableRow.setFloat(i, values(i).asInstanceOf[Float]) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + mutableRow.setDouble(i, values(i).asInstanceOf[Double]) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + mutableRow.update(i, values(i).asInstanceOf[Array[Byte]]) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + mutableRow.update(i, values(i).asInstanceOf[Array[Byte]]) + case _ => + throw new IllegalArgumentException("Unexpected parquet type name") + } + mutableRow + } + + private[sql] def aggResultToSparkColumnarBatch( + parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], + values: Seq[Any], + readDataSchema: StructType, + offHeap: Boolean): ColumnarBatch = { + val capacity = 4 * 1024 + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(capacity, readDataSchema) + } else { + OnHeapColumnVector.allocateColumns(capacity, readDataSchema) + } + + parquetTypes.zipWithIndex.map { + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + columnVectors(i).appendInt(values(i).asInstanceOf[Int]) + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + 
columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + case (PrimitiveType.PrimitiveTypeName.INT96, i) => + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val byteArray = values(i).asInstanceOf[Array[Byte]] + columnVectors(i).appendBytes(byteArray.length, byteArray, 0) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val byteArray = values(i).asInstanceOf[Array[Byte]] + columnVectors(i).appendBytes(byteArray.length, byteArray, 0) + case _ => + throw new IllegalArgumentException("Unexpected parquet type name") + } + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } + + private[sql] def getPushedDownAggResult( + conf: Configuration, + file: Path, + dataSchema: StructType, + aggregation: Aggregation) + : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = { + + val footer = ParquetFooterReader.readFooter(conf, file, NO_FILTER) + val fields = footer.getFileMetaData.getSchema.getFields + val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] + val valuesBuilder = ArrayBuilder.make[Any] + val blocks = footer.getBlocks() + + blocks.forEach { block => + val columns = block.getColumns() + for (i <- 0 until aggregation.aggregateExpressions.size) { + var index = 0 + aggregation.aggregateExpressions(i) match { + case Max(col, _) => + index = dataSchema.fieldNames.toList.indexOf(col) + valuesBuilder += getPushedDownMaxMin(footer, columns, index, true) + typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + case Min(col, _) => + index = dataSchema.fieldNames.toList.indexOf(col) + valuesBuilder += getPushedDownMaxMin(footer, columns, index, false) + typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + case Count(col, _, _) => + index = dataSchema.fieldNames.toList.indexOf(col) + var rowCount = getRowCountFromParquetMetadata(footer) + if (!col.equals("1")) { // count(*) + rowCount -= getNumNulls(footer, columns, index) + } + valuesBuilder += rowCount + typesBuilder += PrimitiveType.PrimitiveTypeName.INT96 + case _ => + } + } + } + (typesBuilder.result(), valuesBuilder.result()) + } + + private def getPushedDownMaxMin( + footer: ParquetMetadata, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int, + isMax: Boolean) = { + val parquetType = footer.getFileMetaData.getSchema.getType(i) + if (!parquetType.isPrimitive) { + throw new IllegalArgumentException("Unsupported type : " + parquetType.toString) + } + var value: Any = None + val statistics = columnChunkMetaData.get(i).getStatistics() + if (isMax) { + val currentMax = statistics.genericGetMax() + if (currentMax != None && + (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { + value = currentMax + } + } else { + val currentMin = statistics.genericGetMin() + if (currentMin != None && + (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { + value = currentMin + } + } + value + } + + private def getRowCountFromParquetMetadata(footer: ParquetMetadata): Long = { + var rowCount: Long = 0 + for (blockMetaData <- footer.getBlocks.asScala) { + rowCount += blockMetaData.getRowCount + } + rowCount + } + + private def getNumNulls( + footer: ParquetMetadata, + columnChunkMetaData: 
util.List[ColumnChunkMetaData], + i: Int): Long = { + val parquetType = footer.getFileMetaData.getSchema.getType(i) + if (!parquetType.isPrimitive) { + throw new IllegalArgumentException("Unsupported type : " + parquetType.toString) + } + var numNulls: Long = 0; + val statistics = columnChunkMetaData.get(i).getStatistics() + if (!statistics.isNumNullsSet()) { + throw new UnsupportedOperationException("Number of nulls not set for parquet file." + + " Set session property hive.pushdown_partial_aggregations_into_scan=false and execute" + + " query again"); + } + numNulls += statistics.getNumNulls(); + numNulls + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 811f41832d159..6d69c8071988a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -86,7 +86,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) => + relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed, + aggregation), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -99,6 +100,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat output.toStructType, translated.toSet, pushed.toSet, + aggregation, unsafeRowRDD, v1Relation, tableIdentifier = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 1f57f17911457..b6192e058e46e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.Aggregation import org.apache.spark.sql.types.StructType object PushDownUtils extends PredicateHelper { @@ -70,6 +72,43 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down aggregates to the data source reader + * + * @return pushed aggregation. 
+ */ + def pushAggregates( + scanBuilder: ScanBuilder, + aggregates: Seq[AggregateExpression], + groupBy: Seq[Expression]): Aggregation = { + + def columnAsString(e: Expression): String = e match { + case AttributeReference(name, _, _, _) => name + case _ => "" + } + + scanBuilder match { + case r: SupportsPushDownAggregates => + val translatedAggregates = mutable.ArrayBuffer.empty[sources.AggregateFunc] + + for (aggregateExpr <- aggregates) { + val translated = DataSourceStrategy.translateAggregate(aggregateExpr) + if (translated.isEmpty) { + return Aggregation.empty + } else { + translatedAggregates += translated.get + } + } + val groupByCols = groupBy.map(columnAsString(_)) + if (!groupByCols.exists(_.isEmpty)) { + r.pushAggregation(Aggregation(translatedAggregates, groupByCols)) + } + r.pushedAggregation + + case _ => Aggregation.empty + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d2180566790ac..dc2aaaca55895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,38 +17,133 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.{Scan, V1Scan} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.{AggregateFunc, Aggregation} import org.apache.spark.sql.types.StructType -object V2ScanRelationPushDown extends Rule[LogicalPlan] { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + case Aggregate(groupingExpressions, resultExpressions, child) => + child match { + case ScanOperation(project, filters, relation: DataSourceV2Relation) => + val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) - val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = - normalizedFilters.partition(SubqueryExpression.hasSubquery) + val aliasMap = getAliasMap(project) + var aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => + replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression] + } + } + aggregates = 
DataSourceStrategy.normalizeExprs(aggregates, relation.output) + .asInstanceOf[Seq[AggregateExpression]] - // `pushedFilters` will be pushed down and evaluated in the underlying data sources. - // `postScanFilters` need to be evaluated after the scan. - // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) - val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ expr => + expr.collect { + case e: Expression => replaceAlias(e, aliasMap) + } + } + val normalizedGroupingExpressions = + DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, relation.output) + + var newFilters = filters + aggregates.foreach(agg => + if (agg.filter.nonEmpty) { + // handle agg filter the same way as other filters + newFilters = newFilters :+ agg.filter.get + } + ) + + val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, newFilters, relation) + if (postScanFilters.nonEmpty) { + Aggregate(groupingExpressions, resultExpressions, child) + } else { // only push down aggregate if all the filers can be push down + val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates, + normalizedGroupingExpressions) + + val (scan, output, normalizedProjects) = + processFilterAndColumn(scanBuilder, project, postScanFilters, relation) + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} + |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val wrappedScan = scan match { + case v1: V1Scan => + val translated = newFilters.flatMap(DataSourceStrategy.translateFilter(_, true)) + V1ScanWrapper(v1, translated, pushedFilters, aggregation) + case _ => scan + } + + if (aggregation.aggregateExpressions.isEmpty) { + Aggregate(groupingExpressions, resultExpressions, child) + } else { + val aggOutputBuilder = ArrayBuilder.make[AttributeReference] + for (i <- 0 until aggregates.length) { + aggOutputBuilder += AttributeReference( + aggregation.aggregateExpressions(i).toString, aggregates(i).dataType)() + } + groupingExpressions.foreach{ + case a@AttributeReference(_, _, _, _) => aggOutputBuilder += a + case _ => + } + val aggOutput = aggOutputBuilder.result + + val r = buildLogicalPlan(aggOutput, relation, wrappedScan, aggOutput, + normalizedProjects, postScanFilters) + val plan = Aggregate(groupingExpressions, resultExpressions, r) + + var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = { + if (agg.aggregateFunction.isInstanceOf[aggregate.Max]) { + aggregate.Max(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Min]) { + aggregate.Min(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Average]) { + aggregate.Average(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Sum]) { + aggregate.Sum(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Count]) { + val count = aggregate.Count(aggOutput(i - 1)) + count.pushDown = true + count + } else { + agg.aggregateFunction + } + } + agg.copy(aggregateFunction = 
aggFunction, filter = None) + } + } + } + + case _ => Aggregate(groupingExpressions, resultExpressions, child) + } + case ScanOperation(project, filters, relation: DataSourceV2Relation) => + val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + val (pushedFilters, postScanFilters) = pushDownFilter (scanBuilder, filters, relation) + val (scan, output, normalizedProjects) = + processFilterAndColumn(scanBuilder, project, postScanFilters, relation) - val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) - .asInstanceOf[Seq[NamedExpression]] - val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) logInfo( s""" |Pushing operators to ${relation.name} @@ -60,31 +155,72 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { val wrappedScan = scan match { case v1: V1Scan => val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) - V1ScanWrapper(v1, translated, pushedFilters) + V1ScanWrapper(v1, translated, pushedFilters, + Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])) + case _ => scan } - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + buildLogicalPlan(project, relation, wrappedScan, output, normalizedProjects, postScanFilters) + } - val projectionOverSchema = ProjectionOverSchema(output.toStructType) - val projectionFunc = (expr: Expression) => expr transformDown { - case projectionOverSchema(newExpr) => newExpr - } + private def pushDownFilter( + scanBuilder: ScanBuilder, + filters: Seq[Expression], + relation: DataSourceV2Relation): (Seq[sources.Filter], Seq[Expression]) = { + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) + val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) - val filterCondition = postScanFilters.reduceLeftOption(And) - val newFilterCondition = filterCondition.map(projectionFunc) - val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) - - val withProjection = if (withFilter.output != project) { - val newProjects = normalizedProjects - .map(projectionFunc) - .asInstanceOf[Seq[NamedExpression]] - Project(newProjects, withFilter) - } else { - withFilter - } + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
+ val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( + scanBuilder, normalizedFiltersWithoutSubquery) + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + (pushedFilters, postScanFilters) + } + + private def processFilterAndColumn( + scanBuilder: ScanBuilder, + project: Seq[NamedExpression], + postScanFilters: Seq[Expression], + relation: DataSourceV2Relation): + (Scan, Seq[AttributeReference], Seq[NamedExpression]) = { + val normalizedProjects = DataSourceStrategy + .normalizeExprs(project, relation.output) + .asInstanceOf[Seq[NamedExpression]] + val (scan, output) = PushDownUtils.pruneColumns( + scanBuilder, relation, normalizedProjects, postScanFilters) + (scan, output, normalizedProjects) + } - withProjection + private def buildLogicalPlan( + project: Seq[NamedExpression], + relation: DataSourceV2Relation, + wrappedScan: Scan, + output: Seq[AttributeReference], + normalizedProjects: Seq[NamedExpression], + postScanFilters: Seq[Expression]): LogicalPlan = { + val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionFunc = (expr: Expression) => expr transformDown { + case projectionOverSchema(newExpr) => newExpr + } + + val filterCondition = postScanFilters.reduceLeftOption(And) + val newFilterCondition = filterCondition.map(projectionFunc) + val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) + + val withProjection = if (withFilter.output != project) { + val newProjects = normalizedProjects + .map(projectionFunc) + .asInstanceOf[Seq[NamedExpression]] + Project(newProjects, withFilter) + } else { + withFilter + } + withProjection } } @@ -93,6 +229,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { case class V1ScanWrapper( v1Scan: V1Scan, translatedFilters: Seq[sources.Filter], - handledFilters: Seq[sources.Filter]) extends Scan { + handledFilters: Seq[sources.Filter], + pushedAggregates: sources.Aggregation) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 78076040e7cf5..b1d144a8f013f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} +import org.apache.spark.sql.types.{AtomicType, LongType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -53,6 +53,7 @@ import org.apache.spark.util.SerializableConfiguration * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. * @param filters Filters to be pushed down in the batch scan. 
+ * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read. */ case class ParquetPartitionReaderFactory( @@ -62,9 +63,16 @@ case class ParquetPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, filters: Array[Filter], + aggregation: Aggregation, parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) + private val aggSchema = buildAggSchema + private val newReadDataSchema = if (aggregation.aggregateExpressions.isEmpty) { + readDataSchema + } else { + aggSchema + } + private val resultSchema = StructType(partitionSchema.fields ++ newReadDataSchema.fields) private val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled private val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) @@ -80,6 +88,31 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private def buildAggSchema: StructType = { + var aggSchema = new StructType() + for (i <- 0 until aggregation.aggregateExpressions.size) { + var index = 0 + aggregation.aggregateExpressions(i) match { + case Max(col, _) => + index = dataSchema.fieldNames.toList.indexOf(col) + val field = dataSchema.fields(index) + aggSchema = aggSchema.add(field.copy("max(" + field.name + ")")) + case Min(col, _) => + index = dataSchema.fieldNames.toList.indexOf(col) + val field = dataSchema.fields(index) + aggSchema = aggSchema.add(field.copy("min(" + field.name + ")")) + case Count(col, _, _) => + if (col.equals("1")) { + aggSchema = aggSchema.add(new StructField("count(*)", LongType)) + } else { + aggSchema = aggSchema.add(new StructField("count(" + col + ")", LongType)) + } + case _ => + } + } + aggSchema + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -87,36 +120,83 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val reader = if (enableVectorizedReader) { - createVectorizedReader(file) + val fileReader = if (aggregation.aggregateExpressions.isEmpty) { + + val reader = if (enableVectorizedReader) { + createVectorizedReader(file) + } else { + createRowBaseReader(file) + } + + new PartitionReader[InternalRow] { + override def next(): Boolean = reader.nextKeyValue() + + override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + + override def close(): Unit = reader.close() + } } else { - createRowBaseReader(file) - } + new PartitionReader[InternalRow] { + var count = 0 - val fileReader = new PartitionReader[InternalRow] { - override def next(): Boolean = reader.nextKeyValue() + override def next(): Boolean = { + val hasNext = if (count == 0) true else false + count += 1 + hasNext + } - override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + override def get(): InternalRow = { + val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + val (parquetTypes, values) = + ParquetUtils.getPushedDownAggResult(conf, 
filePath, dataSchema, aggregation) + ParquetUtils.aggResultToSparkInternalRows(parquetTypes, values, aggSchema) + } - override def close(): Unit = reader.close() + override def close(): Unit = return + } } - new PartitionReaderWithPartitionValues(fileReader, readDataSchema, + new PartitionReaderWithPartitionValues(fileReader, newReadDataSchema, partitionSchema, file.partitionValues) } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val vectorizedReader = createVectorizedReader(file) - vectorizedReader.enableReturningBatches() + val fileReader = if (aggregation.aggregateExpressions.isEmpty) { + val vectorizedReader = createVectorizedReader(file) + vectorizedReader.enableReturningBatches() + + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() - new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def close(): Unit = vectorizedReader.close() + } + } else { + new PartitionReader[ColumnarBatch] { + var count = 0 + + override def next(): Boolean = { + val hasNext = if (count == 0) true else false + count += 1 + hasNext + } - override def close(): Unit = vectorizedReader.close() + override def get(): ColumnarBatch = { + val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + val (parquetTypes, values) = + ParquetUtils.getPushedDownAggResult(conf, filePath, dataSchema, aggregation) + ParquetUtils.aggResultToSparkColumnarBatch(parquetTypes, values, aggSchema, + enableOffHeapColumnVector) + } + + override def close(): Unit = return + } } + fileReader } private def buildReaderBase[T]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 60573ba10ccb6..100d9d60c06c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{Aggregation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -42,6 +42,7 @@ case class ParquetScan( readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], + pushedAggregations: Aggregation = Aggregation.empty, options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { @@ -86,6 +87,7 @@ case class ParquetScan( readDataSchema, readPartitionSchema, pushedFilters, + pushedAggregations, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 44053830defe5..c0a40e611bf56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Min, Max} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -34,7 +34,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters + with SupportsPushDownAggregates { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -70,8 +71,32 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. 
override def pushedFilters(): Array[Filter] = pushedParquetFilters + private var pushedAggregations = Aggregation.empty + + override def pushAggregation(aggregation: Aggregation): Unit = { + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByExpressions.nonEmpty) { + Aggregation.empty + return + } + + aggregation.aggregateExpressions.foreach { agg => + if (!agg.isInstanceOf[Max] && !agg.isInstanceOf[Min] && !agg.isInstanceOf[Count]) { + Aggregation.empty + return + } else if (agg.isInstanceOf[Count] && agg.asInstanceOf[Count].isDistinct) { + // parquet's statistics doesn't have distinct count info + Aggregation.empty + return + } + } + this.pushedAggregations = aggregation + } + + override def pushedAggregation(): Aggregation = pushedAggregations + override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + readPartitionSchema(), pushedParquetFilters, pushedAggregations, options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index e9f9bf8df35e6..261774a3e1e49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -38,7 +38,7 @@ case class ParquetTable( extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): ParquetScanBuilder = - new ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 9ef43995467c6..373af907a9a78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -50,6 +50,54 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } + test("test aggregate pushdown") { + spark.conf.set(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key, "true") + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + spark.createDataFrame(data).toDF("c1", "c2", "c3").createOrReplaceTempView("tmp") + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + " count(*), count(_1), count(_2), count(_3) FROM t") + // selectAgg.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias('min('_3), None), unresolvedalias('min('_3), None), unresolvedalias('max('_3), None), unresolvedalias('min('_1), None), unresolvedalias('max('_1), None), unresolvedalias('max('_1), None), unresolvedalias('count(1), None), unresolvedalias('count('_1), None), unresolvedalias('count('_2), None), unresolvedalias('count('_3), None)] + // +- 
'UnresolvedRelation [t], [], false + // + // == Analyzed Logical Plan == + // min(_3): int, min(_3): int, max(_3): int, min(_1): int, max(_1): int, max(_1): int, count(1): bigint, count(_1): bigint, count(_2): bigint, count(_3): bigint + // Aggregate [min(_3#23) AS min(_3)#43, min(_3#23) AS min(_3)#44, max(_3#23) AS max(_3)#45, min(_1#21) AS min(_1)#46, max(_1#21) AS max(_1)#47, max(_1#21) AS max(_1)#48, count(1) AS count(1)#49L, count(_1#21) AS count(_1)#50L, count(_2#22) AS count(_2)#51L, count(_3#23) AS count(_3)#52L] + // +- SubqueryAlias t + // +- View (`t`, [_1#21,_2#22,_3#23]) + // +- RelationV2[_1#21, _2#22, _3#23] parquet file:/private/var/folders/hm/dghdj3hn791fd9bfmwnl12km0000gn/T/spark-6609ad86-7ff5-4f83-96e2-fea5f2c85646 + // + // == Optimized Logical Plan == + // Aggregate [min(Min(_3,IntegerType)#66) AS min(_3)#43, min(Min(_3,IntegerType)#67) AS min(_3)#44, max(Max(_3,IntegerType)#68) AS max(_3)#45, min(Min(_1,IntegerType)#69) AS min(_1)#46, max(Max(_1,IntegerType)#70) AS max(_1)#47, max(Max(_1,IntegerType)#71) AS max(_1)#48, count(Count(1,LongType,false)#72L) AS count(1)#49L, count(Count(_1,LongType,false)#73L) AS count(_1)#50L, count(Count(_2,LongType,false)#74L) AS count(_2)#51L, count(Count(_3,LongType,false)#75L) AS count(_3)#52L] + // +- RelationV2[Min(_3,IntegerType)#66, Min(_3,IntegerType)#67, Max(_3,IntegerType)#68, Min(_1,IntegerType)#69, Max(_1,IntegerType)#70, Max(_1,IntegerType)#71, Count(1,LongType,false)#72L, Count(_1,LongType,false)#73L, Count(_2,LongType,false)#74L, Count(_3,LongType,false)#75L] parquet file:/private/var/folders/hm/dghdj3hn791fd9bfmwnl12km0000gn/T/spark-6609ad86-7ff5-4f83-96e2-fea5f2c85646 + // + // == Physical Plan == + // AdaptiveSparkPlan isFinalPlan=false + // +- HashAggregate(keys=[], functions=[min(Min(_3,IntegerType)#66), min(Min(_3,IntegerType)#67), max(Max(_3,IntegerType)#68), min(Min(_1,IntegerType)#69), max(Max(_1,IntegerType)#70), max(Max(_1,IntegerType)#71), count(Count(1,LongType,false)#72L), count(Count(_1,LongType,false)#73L), count(Count(_2,LongType,false)#74L), count(Count(_3,LongType,false)#75L)], output=[min(_3)#43, min(_3)#44, max(_3)#45, min(_1)#46, max(_1)#47, max(_1)#48, count(1)#49L, count(_1)#50L, count(_2)#51L, count(_3)#52L]) + // +- HashAggregate(keys=[], functions=[partial_min(Min(_3,IntegerType)#66), partial_min(Min(_3,IntegerType)#67), partial_max(Max(_3,IntegerType)#68), partial_min(Min(_1,IntegerType)#69), partial_max(Max(_1,IntegerType)#70), partial_max(Max(_1,IntegerType)#71), partial_count(Count(1,LongType,false)#72L), partial_count(Count(_1,LongType,false)#73L), partial_count(Count(_2,LongType,false)#74L), partial_count(Count(_3,LongType,false)#75L)], output=[min#86, min#87, max#88, min#89, max#90, max#91, count#92L, count#93L, count#94L, count#95L]) + // +- Project [Min(_3,IntegerType)#66, Min(_3,IntegerType)#67, Max(_3,IntegerType)#68, Min(_1,IntegerType)#69, Max(_1,IntegerType)#70, Max(_1,IntegerType)#71, Count(1,LongType,false)#72L, Count(_1,LongType,false)#73L, Count(_2,LongType,false)#74L, Count(_3,LongType,false)#75L] + // +- BatchScan[Min(_3,IntegerType)#66, Min(_3,IntegerType)#67, Max(_3,IntegerType)#68, Min(_1,IntegerType)#69, Max(_1,IntegerType)#70, Max(_1,IntegerType)#71, Count(1,LongType,false)#72L, Count(_1,LongType,false)#73L, Count(_2,LongType,false)#74L, Count(_3,LongType,false)#75L] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/hm/dghdj3hn791fd9bfmwnl12km0000gn/T/spark-66..., PartitionFilters: [], PushedFilters: [], 
ReadSchema: struct<_1:int,_2:string,_3:int>, PushedFilters: [] + // scalastyle:on line.size.limit + + // selectAgg.show() + // +-------+-------+-------+-------+-------+-------+--------+---------+---------+---------+ + // |min(_3)|min(_3)|max(_3)|min(_1)|max(_1)|max(_1)|count(1)|count(_1)|count(_2)|count(_3)| + // +-------+-------+-------+-------+-------+-------+--------+---------+---------+---------+ + // | 2| 2| 19| -2| 9| 9| 6| 6| 4| 6| + // +-------+-------+-------+-------+-------+-------+--------+---------+---------+---------+ + + checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + } + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) + spark.conf.unset(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key) + } + test("appending") { val data = (0 until 10).map(i => (i, i.toString)) spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") From 3b7a1704fb13030bb725ec5d413726bbc202e277 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 29 Mar 2021 12:44:58 -0700 Subject: [PATCH 02/30] change count --- .../expressions/aggregate/CountBase.scala | 75 +++++++++++++++++++ .../expressions/aggregate/PushDownCount.scala | 41 ++++++++++ .../datasources/DataSourceStrategy.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 3 +- 4 files changed, 118 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala new file mode 100644 index 0000000000000..003bf5b51a204 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +abstract class CountBase(children: Seq[Expression]) extends DeclarativeAggregate { + + override def nullable: Boolean = false + + // Return data type. 
+ override def dataType: DataType = LongType + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " + + s"If you have to call the function $prettyName without arguments, set the legacy " + + s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + protected lazy val count = AttributeReference("count", LongType, nullable = false)() + + override lazy val aggBufferAttributes = count :: Nil + + override lazy val initialValues = Seq( + /* count = */ Literal(0L) + ) + + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) + + private[sql] var pushDown: Boolean = false + + override lazy val updateExpressions = { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } + } +} + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala new file mode 100644 index 0000000000000..537efdf38ce2d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.LongType + +case class PushDownCount(children: Seq[Expression], pushdown: Boolean) extends CountBase(children) { + + protected override lazy val count = + AttributeReference("PushDownCount", LongType, nullable = false)() + + override lazy val updateExpressions = { + Seq( + // if count is pushed down to Data Source layer, add the count result retrieved from + // Data Source + /* count = */ count + children.head + ) + } +} + +object PushDownCount { + def apply(child: Expression, pushdown: Boolean): PushDownCount = + PushDownCount(child :: Nil, pushdown) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index b8b6c29ae77e1..11d98899602f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -696,7 +696,6 @@ object DataSourceStrategy } protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { - aggregates.aggregateFunction match { case min: aggregate.Min => val colName = columnAsString(min.child) @@ -712,8 +711,7 @@ object DataSourceStrategy if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None case count: aggregate.Count => val columnName = count.children.head match { - case Literal(_, _) => - "1" + case Literal(_, _) => "1" case _ => columnAsString(count.children.head) } if (columnName.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index dc2aaaca55895..45b3be0316623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -125,8 +125,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { aggregate.Sum(aggOutput(i - 1)) } else if (agg.aggregateFunction.isInstanceOf[aggregate.Count]) { val count = aggregate.Count(aggOutput(i - 1)) - count.pushDown = true - count + aggregate.PushDownCount(aggOutput(i - 1), true) } else { agg.aggregateFunction } From a4b054ad05a51a4800f8317de79ebd9c141eadee Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 30 Mar 2021 09:09:05 -0700 Subject: [PATCH 03/30] change countpwd --- .../sql/execution/CollapseAggregates.scala | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala new file mode 100644 index 0000000000000..1d6433d2e9b80 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +// import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +/** + * Collapse Physical aggregate exec nodes together if there is no exchange between them and they + * correspond to Partial and Final Aggregation for same + * [[org.apache.spark.sql.catalyst.plans.logical.Aggregate]] logical node. + */ +object CollapseAggregates extends Rule[SparkPlan] { + + override def apply(plan: SparkPlan): SparkPlan = { + collapseAggregates(plan) + } + + private def collapseAggregates(plan: SparkPlan): SparkPlan = { + plan transform { + case parent@HashAggregateExec(_, _, _, _, _, _, child: HashAggregateExec) + if checkIfAggregatesCanBeCollapsed(parent, child) => + val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete)) + HashAggregateExec( + requiredChildDistributionExpressions = Some(child.groupingExpressions), + groupingExpressions = child.groupingExpressions, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), + initialInputBufferOffset = 0, + resultExpressions = parent.resultExpressions, + child = child.child) + + case parent@SortAggregateExec(_, _, _, _, _, _, child: SortAggregateExec) + if checkIfAggregatesCanBeCollapsed(parent, child) => + val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete)) + SortAggregateExec( + requiredChildDistributionExpressions = Some(child.groupingExpressions), + groupingExpressions = child.groupingExpressions, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), + initialInputBufferOffset = 0, + resultExpressions = parent.resultExpressions, + child = child.child) + + case parent@ObjectHashAggregateExec(_, _, _, _, _, _, child: ObjectHashAggregateExec) + if checkIfAggregatesCanBeCollapsed(parent, child) => + val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete)) + ObjectHashAggregateExec( + requiredChildDistributionExpressions = Some(child.groupingExpressions), + groupingExpressions = child.groupingExpressions, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), + initialInputBufferOffset = 0, + resultExpressions = parent.resultExpressions, + child = child.child) + case parent@HashAggregateExec(_, _, _, _, _, _, child) => + // if 
checkIfAggregatesCanBeCollapsed(parent, child) => + child match { + case ShuffleExchangeExec(_, c, _) => + c match { + case agg@HashAggregateExec(_, _, _, _, _, _, c2) => + c2 match { + case r: SparkPlan => + val completeAggregateExpressions = + agg.aggregateExpressions.map(_.copy(mode = Complete)) + HashAggregateExec( + requiredChildDistributionExpressions = Some(agg.groupingExpressions), + groupingExpressions = agg.groupingExpressions, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), + initialInputBufferOffset = 0, + resultExpressions = parent.resultExpressions, + child = agg.child) + } + } + } + } + } + + private def checkIfAggregatesCanBeCollapsed( + parent: BaseAggregateExec, + child: BaseAggregateExec): Boolean = { + val parentHasFinalMode = parent.aggregateExpressions.forall(_.mode == Final) + if (!parentHasFinalMode) { + return false + } + val childHasPartialMode = child.aggregateExpressions.forall(_.mode == Partial) + if (!childHasPartialMode) { + return false + } + val parentChildAggExpressionsSame = parent.aggregateExpressions.map( + _.copy(mode = Partial)) == child.aggregateExpressions + if (!parentChildAggExpressionsSame) { + return false + } + true + } +} \ No newline at end of file From d576cf64e9ee7f7349f5039f5a9377cda28393e8 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 3 Apr 2021 23:40:11 -0700 Subject: [PATCH 04/30] [SPARK-34952][SQL] Aggregate (Min/Max/Count) push down for Parquet --- .../expressions/aggregate/PushDownCount.scala | 19 +- .../apache/spark/sql/sources/aggregates.scala | 6 +- .../sql/execution/CollapseAggregates.scala | 116 --------- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 27 +- .../datasources/parquet/ParquetUtils.scala | 234 +++++++++++++++--- .../datasources/v2/PushDownUtils.scala | 21 +- .../v2/V2ScanRelationPushDown.scala | 45 ++-- .../ParquetPartitionReaderFactory.scala | 32 ++- .../v2/parquet/ParquetScanBuilder.scala | 4 +- 10 files changed, 286 insertions(+), 220 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala index 537efdf38ce2d..79de67063c8b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala @@ -19,19 +19,20 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.LongType case class PushDownCount(children: Seq[Expression], pushdown: Boolean) extends CountBase(children) { - protected override lazy val count = - AttributeReference("PushDownCount", LongType, nullable = false)() - override lazy val updateExpressions = { - Seq( - // if count is pushed down to Data Source layer, add the count result retrieved from - // Data Source - /* count = */ count + children.head - ) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + children.head + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + children.head) + ) + } } } diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala index 3bc263ff49f1f..2b98454046a47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.types.DataType -// groupBy only used by JDBC agg pushdown, not supported by parquet agg pushdown yet +// Aggregate Functions in SQL statement. +// e.g. SELECT COUNT(EmployeeID), AVG(salary), deptID FROM dept GROUP BY deptID +// aggregateExpressions are (COUNT(EmployeeID), AVG(salary)), groupByColumns are (deptID) case class Aggregation(aggregateExpressions: Seq[AggregateFunc], - groupByExpressions: Seq[String]) + groupByColumns: Seq[String]) abstract class AggregateFunc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala deleted file mode 100644 index 1d6433d2e9b80..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} -// import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec - -/** - * Collapse Physical aggregate exec nodes together if there is no exchange between them and they - * correspond to Partial and Final Aggregation for same - * [[org.apache.spark.sql.catalyst.plans.logical.Aggregate]] logical node. 
- */ -object CollapseAggregates extends Rule[SparkPlan] { - - override def apply(plan: SparkPlan): SparkPlan = { - collapseAggregates(plan) - } - - private def collapseAggregates(plan: SparkPlan): SparkPlan = { - plan transform { - case parent@HashAggregateExec(_, _, _, _, _, _, child: HashAggregateExec) - if checkIfAggregatesCanBeCollapsed(parent, child) => - val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete)) - HashAggregateExec( - requiredChildDistributionExpressions = Some(child.groupingExpressions), - groupingExpressions = child.groupingExpressions, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), - initialInputBufferOffset = 0, - resultExpressions = parent.resultExpressions, - child = child.child) - - case parent@SortAggregateExec(_, _, _, _, _, _, child: SortAggregateExec) - if checkIfAggregatesCanBeCollapsed(parent, child) => - val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete)) - SortAggregateExec( - requiredChildDistributionExpressions = Some(child.groupingExpressions), - groupingExpressions = child.groupingExpressions, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), - initialInputBufferOffset = 0, - resultExpressions = parent.resultExpressions, - child = child.child) - - case parent@ObjectHashAggregateExec(_, _, _, _, _, _, child: ObjectHashAggregateExec) - if checkIfAggregatesCanBeCollapsed(parent, child) => - val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete)) - ObjectHashAggregateExec( - requiredChildDistributionExpressions = Some(child.groupingExpressions), - groupingExpressions = child.groupingExpressions, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), - initialInputBufferOffset = 0, - resultExpressions = parent.resultExpressions, - child = child.child) - case parent@HashAggregateExec(_, _, _, _, _, _, child) => - // if checkIfAggregatesCanBeCollapsed(parent, child) => - child match { - case ShuffleExchangeExec(_, c, _) => - c match { - case agg@HashAggregateExec(_, _, _, _, _, _, c2) => - c2 match { - case r: SparkPlan => - val completeAggregateExpressions = - agg.aggregateExpressions.map(_.copy(mode = Complete)) - HashAggregateExec( - requiredChildDistributionExpressions = Some(agg.groupingExpressions), - groupingExpressions = agg.groupingExpressions, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), - initialInputBufferOffset = 0, - resultExpressions = parent.resultExpressions, - child = agg.child) - } - } - } - } - } - - private def checkIfAggregatesCanBeCollapsed( - parent: BaseAggregateExec, - child: BaseAggregateExec): Boolean = { - val parentHasFinalMode = parent.aggregateExpressions.forall(_.mode == Final) - if (!parentHasFinalMode) { - return false - } - val childHasPartialMode = child.aggregateExpressions.forall(_.mode == Partial) - if (!childHasPartialMode) { - return false - } - val parentChildAggExpressionsSame = parent.aggregateExpressions.map( - _.copy(mode = Partial)) == child.aggregateExpressions - if (!parentChildAggExpressionsSame) { - return false - } - true - } -} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 29f9bb8be0900..c50819893812c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -136,7 +136,7 @@ case class RowDataSourceScanExec( val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield { s"*$aggregate" } - val markedGroupby = for (groupby <- aggregation.groupByExpressions) yield { + val markedGroupby = for (groupby <- aggregation.groupByColumns) yield { s"*$groupby" } Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 11d98899602f8..2c1a2616c67e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -677,25 +677,14 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - private def columnAsString(e: Expression): String = e match { - case AttributeReference(name, _, _, _) => name - case Cast(child, _, _) => columnAsString (child) - // Add, Subtract, Multiply and Divide are only supported by JDBC agg pushdown - case Add(left, right, _) => - columnAsString(left) + " + " + columnAsString(right) - case Subtract(left, right, _) => - columnAsString(left) + " - " + columnAsString(right) - case Multiply(left, right, _) => - columnAsString(left) + " * " + columnAsString(right) - case Divide(left, right, _) => - columnAsString(left) + " / " + columnAsString(right) - - case CheckOverflow(child, _, _) => columnAsString (child) - case PromotePrecision(child) => columnAsString (child) - case _ => "" - } - protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { + + def columnAsString(e: Expression): String = e match { + case AttributeReference(name, _, _, _) => name + case Cast(child, _, _) => columnAsString(child) + case _ => "" + } + aggregates.aggregateFunction match { case min: aggregate.Min => val colName = columnAsString(min.child) @@ -711,7 +700,7 @@ object DataSourceStrategy if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None case count: aggregate.Count => val columnName = count.children.head match { - case Literal(_, _) => "1" + case Literal(_, _) => "1" // SELECT (*) FROM table is translated to SELECT 1 FROM table case _ => columnAsString(count.children.head) } if (columnName.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 9afa2974a49b7..db0137c5bf2f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,26 +16,31 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal, BigInteger} +import java.time.{ZoneId, ZoneOffset} import java.util import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuilder import scala.language.existentials -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} 
-import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER import org.apache.parquet.hadoop.ParquetFileWriter import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.PrimitiveType import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.sources.{Aggregation, Count, Max, Min} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.unsafe.types.UTF8String + object ParquetUtils { def inferSchema( @@ -144,26 +149,124 @@ object ParquetUtils { } private[sql] def aggResultToSparkInternalRows( + footer: ParquetMetadata, parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], values: Seq[Any], - dataSchema: StructType): InternalRow = { + dataSchema: StructType, + datetimeRebaseModeInRead: String, + int96RebaseModeInRead: String, + convertTz: Option[ZoneId]): InternalRow = { val mutableRow = new SpecificInternalRow(dataSchema.fields.map(x => x.dataType)) - + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val int96RebaseMode = DataSourceUtils.int96RebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + int96RebaseModeInRead) parquetTypes.zipWithIndex.map { case (PrimitiveType.PrimitiveTypeName.INT32, i) => - mutableRow.setInt(i, values(i).asInstanceOf[Int]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + dataSchema.fields(i).dataType match { + case b: ByteType => + mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) + case s: ShortType => + mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) + case int: IntegerType => + mutableRow.setInt(i, values(i).asInstanceOf[Integer]) + case d: DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for INT32") + } + } case (PrimitiveType.PrimitiveTypeName.INT64, i) => - mutableRow.setLong(i, values(i).asInstanceOf[Long]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + dataSchema.fields(i).dataType match { + case long: LongType => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for INT64") + } + } case (PrimitiveType.PrimitiveTypeName.INT96, i) => - mutableRow.setLong(i, values(i).asInstanceOf[Long]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + dataSchema.fields(i).dataType match 
{ + case l: LongType => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case d: TimestampType => + val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( + int96RebaseMode, "Parquet INT96") + val julianMicros = + ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = + convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + mutableRow.setLong(i, adjTime) + case _ => + } + } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => - mutableRow.setFloat(i, values(i).asInstanceOf[Float]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, values(i).asInstanceOf[Float]) + } case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => - mutableRow.setDouble(i, values(i).asInstanceOf[Double]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, values(i).asInstanceOf[Double]) + } + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) + } case (PrimitiveType.PrimitiveTypeName.BINARY, i) => - mutableRow.update(i, values(i).asInstanceOf[Array[Byte]]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + val bytes = values(i).asInstanceOf[Binary].getBytes + dataSchema.fields(i).dataType match { + case s: StringType => + mutableRow.update(i, UTF8String.fromBytes(bytes)) + case b: BinaryType => + mutableRow.update(i, bytes) + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for Binary") + } + } case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => - mutableRow.update(i, values(i).asInstanceOf[Array[Byte]]) + if (values(i) == null) { + mutableRow.setNullAt(i) + } else { + val bytes = values(i).asInstanceOf[Binary].getBytes + dataSchema.fields(i).dataType match { + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for FIXED_LEN_BYTE_ARRAY") + } + } case _ => throw new IllegalArgumentException("Unexpected parquet type name") } @@ -171,34 +274,105 @@ object ParquetUtils { } private[sql] def aggResultToSparkColumnarBatch( + footer: ParquetMetadata, parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], values: Seq[Any], - readDataSchema: StructType, - offHeap: Boolean): ColumnarBatch = { + dataSchema: StructType, + offHeap: Boolean, + datetimeRebaseModeInRead: String, + int96RebaseModeInRead: String, + convertTz: Option[ZoneId]): ColumnarBatch = { val capacity = 4 * 1024 + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val int96RebaseMode = DataSourceUtils.int96RebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + int96RebaseModeInRead) val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(capacity, readDataSchema) + OffHeapColumnVector.allocateColumns(capacity, dataSchema) } else { - OnHeapColumnVector.allocateColumns(capacity, readDataSchema) + OnHeapColumnVector.allocateColumns(capacity, dataSchema) } 
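+    // Write each aggregated value into its column vector, dispatching on the Parquet physical type; a null aggregate value is appended as a null entry.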
parquetTypes.zipWithIndex.map { case (PrimitiveType.PrimitiveTypeName.INT32, i) => - columnVectors(i).appendInt(values(i).asInstanceOf[Int]) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + dataSchema.fields(i).dataType match { + case b: ByteType => + columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) + case s: ShortType => + columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) + case int: IntegerType => + columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) + case d: DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) + case _ => throw new IllegalArgumentException("Unexpected type for INT32") + } + } case (PrimitiveType.PrimitiveTypeName.INT64, i) => - columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + } case (PrimitiveType.PrimitiveTypeName.INT96, i) => - columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + dataSchema.fields(i).dataType match { + case l: LongType => + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + case d: TimestampType => + val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( + int96RebaseMode, "Parquet INT96") + val julianMicros = + ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = + convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + columnVectors(i).appendLong(adjTime) + case _ => throw new IllegalArgumentException("Unexpected type for INT96") + } + } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => - columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) + } case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => - columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) + } case (PrimitiveType.PrimitiveTypeName.BINARY, i) => - val byteArray = values(i).asInstanceOf[Array[Byte]] - columnVectors(i).appendBytes(byteArray.length, byteArray, 0) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) + } case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => - val byteArray = values(i).asInstanceOf[Array[Byte]] - columnVectors(i).appendBytes(byteArray.length, byteArray, 0) + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) + } + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + if (values(i) == null) { + columnVectors(i).appendNull() + } else { + columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) + } case _ => throw new IllegalArgumentException("Unexpected parquet type name") } @@ -206,17 +380,15 @@ object ParquetUtils { } private[sql] def getPushedDownAggResult( - conf: Configuration, - file: Path, + footer: ParquetMetadata, dataSchema: 
StructType, aggregation: Aggregation) : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = { - - val footer = ParquetFooterReader.readFooter(conf, file, NO_FILTER) - val fields = footer.getFileMetaData.getSchema.getFields + val footerFileMetaData = footer.getFileMetaData + val fields = footerFileMetaData.getSchema.getFields + val blocks = footer.getBlocks() val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] val valuesBuilder = ArrayBuilder.make[Any] - val blocks = footer.getBlocks() blocks.forEach { block => val columns = block.getColumns() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index b6192e058e46e..3e8885cbfa9df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -89,22 +89,15 @@ object PushDownUtils extends PredicateHelper { scanBuilder match { case r: SupportsPushDownAggregates => - val translatedAggregates = mutable.ArrayBuffer.empty[sources.AggregateFunc] + val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate) + val translatedGroupBys = groupBy.map(columnAsString) - for (aggregateExpr <- aggregates) { - val translated = DataSourceStrategy.translateAggregate(aggregateExpr) - if (translated.isEmpty) { - return Aggregation.empty - } else { - translatedAggregates += translated.get - } - } - val groupByCols = groupBy.map(columnAsString(_)) - if (!groupByCols.exists(_.isEmpty)) { - r.pushAggregation(Aggregation(translatedAggregates, groupByCols)) + if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) { + Aggregation.empty + } else { + r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys)) + r.pushedAggregation } - r.pushedAggregation - case _ => Aggregation.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 45b3be0316623..00c398b24a70d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -34,7 +34,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case Aggregate(groupingExpressions, resultExpressions, child) => + case aggNode@Aggregate(groupingExpressions, resultExpressions, child) => child match { case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) @@ -67,7 +67,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, newFilters, relation) if (postScanFilters.nonEmpty) { - Aggregate(groupingExpressions, resultExpressions, child) + aggNode // return original plan node } else { // only push down aggregate if all the filers can be push down val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates, normalizedGroupingExpressions) @@ -81,7 +81,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { 
|Pushed Filters: ${pushedFilters.mkString(", ")} |Post-Scan Filters: ${postScanFilters.mkString(",")} |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} - |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")} + |Pushed Groupby: ${aggregation.groupByColumns.mkString(", ")} |Output: ${output.mkString(", ")} """.stripMargin) @@ -93,8 +93,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { } if (aggregation.aggregateExpressions.isEmpty) { - Aggregate(groupingExpressions, resultExpressions, child) + aggNode // return original plan node } else { + // build the aggregate expressions + groupby expressions val aggOutputBuilder = ArrayBuilder.make[AttributeReference] for (i <- 0 until aggregates.length) { aggOutputBuilder += AttributeReference( @@ -111,31 +112,35 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { val plan = Aggregate(groupingExpressions, resultExpressions, r) var i = 0 + // scalastyle:off line.size.limit + // change the original optimized logical plan to reflect the pushed down aggregate + // e.g. sql("select max(id), min(id) FROM h2.test.people") + // the the original optimized logical plan is + // == Optimized Logical Plan == + // Aggregate [max(ID#35) AS max(ID)#38, min(ID#35) AS min(ID)#39] + // +- RelationV2[ID#35] test.people + // We want to change it to the following + // == Optimized Logical Plan == + // Aggregate [max(Max(ID,IntegerType)#298) AS max(ID)#293, min(Min(ID,IntegerType)#299) AS min(ID)#294] + // +- RelationV2[Max(ID,IntegerType)#298, Min(ID,IntegerType)#299] test.people + // scalastyle:on line.size.limit plan.transformExpressions { case agg: AggregateExpression => i += 1 - val aggFunction: aggregate.AggregateFunction = { - if (agg.aggregateFunction.isInstanceOf[aggregate.Max]) { - aggregate.Max(aggOutput(i - 1)) - } else if (agg.aggregateFunction.isInstanceOf[aggregate.Min]) { - aggregate.Min(aggOutput(i - 1)) - } else if (agg.aggregateFunction.isInstanceOf[aggregate.Average]) { - aggregate.Average(aggOutput(i - 1)) - } else if (agg.aggregateFunction.isInstanceOf[aggregate.Sum]) { - aggregate.Sum(aggOutput(i - 1)) - } else if (agg.aggregateFunction.isInstanceOf[aggregate.Count]) { - val count = aggregate.Count(aggOutput(i - 1)) - aggregate.PushDownCount(aggOutput(i - 1), true) - } else { - agg.aggregateFunction - } + val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { + case max: aggregate.Max => aggregate.Max(aggOutput(i - 1)) + case min: aggregate.Min => aggregate.Min(aggOutput(i - 1)) + case sum: aggregate.Sum => aggregate.Sum(aggOutput(i - 1)) + case avg: aggregate.Average => aggregate.Average(aggOutput(i - 1)) + case count: aggregate.Count => aggregate.PushDownCount(aggOutput(i - 1), true) + case _ => agg.aggregateFunction } agg.copy(aggregateFunction = aggFunction, filter = None) } } } - case _ => Aggregate(groupingExpressions, resultExpressions, child) + case _ => aggNode // return original plan node } case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index b1d144a8f013f..62a3a4ddc6167 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} import org.apache.spark.TaskContext @@ -148,9 +148,19 @@ case class ParquetPartitionReaderFactory( override def get(): InternalRow = { val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) + val footer = ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + def isCreatedByParquetMr: Boolean = + footer.getFileMetaData.getCreatedBy().startsWith("parquet-mr") + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } val (parquetTypes, values) = - ParquetUtils.getPushedDownAggResult(conf, filePath, dataSchema, aggregation) - ParquetUtils.aggResultToSparkInternalRows(parquetTypes, values, aggSchema) + ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) + ParquetUtils.aggResultToSparkInternalRows(footer, parquetTypes, values, aggSchema, + datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz) } override def close(): Unit = return @@ -185,12 +195,22 @@ case class ParquetPartitionReaderFactory( } override def get(): ColumnarBatch = { + val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) + val footer = ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + def isCreatedByParquetMr: Boolean = + footer.getFileMetaData.getCreatedBy().startsWith("parquet-mr") + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } val (parquetTypes, values) = - ParquetUtils.getPushedDownAggResult(conf, filePath, dataSchema, aggregation) - ParquetUtils.aggResultToSparkColumnarBatch(parquetTypes, values, aggSchema, - enableOffHeapColumnVector) + ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) + ParquetUtils.aggResultToSparkColumnarBatch(footer, parquetTypes, values, aggSchema, + enableOffHeapColumnVector, datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz) } override def close(): Unit = return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index c0a40e611bf56..d7fa474e1133c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, Su import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} 
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Min, Max} +import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -75,7 +75,7 @@ case class ParquetScanBuilder( override def pushAggregation(aggregation: Aggregation): Unit = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByExpressions.nonEmpty) { + aggregation.groupByColumns.nonEmpty) { Aggregation.empty return } From a66a87c3a1570d8e8c1dcd0ce68bc18d46b59ce1 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 4 Apr 2021 00:46:29 -0700 Subject: [PATCH 05/30] remove blank --- .../spark/sql/catalyst/expressions/aggregate/CountBase.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala index 003bf5b51a204..4a30611727b2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala @@ -71,5 +71,3 @@ abstract class CountBase(children: Seq[Expression]) extends DeclarativeAggregate } } } - - From e561cc297be62eaec6b5f65b91f6c6e2d4cf8996 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 4 Apr 2021 09:02:49 -0700 Subject: [PATCH 06/30] fix build failure --- .../datasources/v2/parquet/ParquetScan.scala | 15 ++++++++++++--- .../org/apache/spark/sql/FileScanSuite.scala | 5 +++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 100d9d60c06c7..d634f85564f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -94,21 +94,30 @@ case class ParquetScan( override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) + equivalentFilters(pushedFilters, p.pushedFilters) && + equivalentAggregations(pushedAggregations, p.pushedAggregations) case _ => false } override def hashCode(): Int = getClass.hashCode() override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + seqToString(pushedAggregations.aggregateExpressions) } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> seqToString(pushedAggregations.aggregateExpressions)) } override def withFilters( partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + + // Returns whether the two given [[Aggregation]]s are equivalent. 
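+  // The check is order-insensitive: both sides are sorted by hashCode before the element-wise comparison.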
+ private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 4e7fe8455ff93..fcc95de2d213d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.json.JsonScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.datasources.v2.text.TextScan -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{Aggregation, Filter} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -354,7 +354,8 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, Aggregation.empty, + o, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => From d52a72c0e663c5eb408efe8c781a111c4f8b1a2c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 4 Apr 2021 23:05:50 -0700 Subject: [PATCH 07/30] address comments --- .../datasources/DataSourceStrategy.scala | 39 ++----- .../datasources/parquet/ParquetUtils.scala | 104 ++++++++++++------ .../datasources/v2/PushDownUtils.scala | 4 +- .../ParquetPartitionReaderFactory.scala | 1 - 4 files changed, 83 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2c1a2616c67e6..3fbca206874b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -677,36 +677,21 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { - - def columnAsString(e: Expression): String = e match { - case AttributeReference(name, _, _, _) => name - case Cast(child, _, _) => columnAsString(child) - case _ => "" - } - + protected[sql] def translateAggregate( + aggregates: AggregateExpression, + pushableColumn: PushableColumnBase): Option[AggregateFunc] = { aggregates.aggregateFunction match { - case min: aggregate.Min => - val colName = columnAsString(min.child) - if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None - case max: aggregate.Max => - val colName = columnAsString(max.child) - if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None - case avg: aggregate.Average => - val colName = columnAsString(avg.child) - if (colName.nonEmpty) Some(Avg(colName, avg.dataType, aggregates.isDistinct)) else None - case sum: aggregate.Sum => - val colName = columnAsString(sum.child) - if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) 
else None - case count: aggregate.Count => + case min@aggregate.Min(pushableColumn(name)) => + Some(Min(name, min.dataType)) + case max@aggregate.Max(pushableColumn(name)) => + Some(Max(name, max.dataType)) + case count@aggregate.Count(pushableColumn(name)) => val columnName = count.children.head match { - case Literal(_, _) => "1" // SELECT (*) FROM table is translated to SELECT 1 FROM table - case _ => columnAsString(count.children.head) - } - if (columnName.nonEmpty) { - Some(Count(columnName, count.dataType, aggregates.isDistinct)) + // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table + case Literal(_, _) => "1" + case _ => name } - else None + Some(Count(columnName, count.dataType, aggregates.isDistinct)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index db0137c5bf2f6..ea1d272ba95ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -148,6 +148,14 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + /** + * When the Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to + * createRowBaseReader to read data from parquet and aggregate at spark layer. Instead we want + * to calculate the Aggregates (Max/Min/Count) result using the statistics information + * from parquet footer file, and then construct an InternalRow from these Aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ private[sql] def aggResultToSparkInternalRows( footer: ParquetMetadata, parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], @@ -194,7 +202,7 @@ object ParquetUtils { case long: LongType => mutableRow.setLong(i, values(i).asInstanceOf[Long]) case d: DecimalType => - val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) + val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) case _ => throw new IllegalArgumentException("Unexpected type for INT64") } @@ -216,7 +224,7 @@ object ParquetUtils { convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) mutableRow.setLong(i, adjTime) - case _ => + case _ => throw new IllegalArgumentException("Unexpected type for INT96") } } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => @@ -273,6 +281,15 @@ object ParquetUtils { mutableRow } + /** + * When the Aggregates (Max/Min/Count) are pushed down to parquet, in the case of + * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader + * to read data from parquet and aggregate at spark layer. Instead we want + * to calculate the Aggregates (Max/Min/Count) result using the statistics information + * from parquet footer file, and then construct a ColumnarBatch from these Aggregate results. + * + * @return Aggregate results in the format of ColumnarBatch + */ private[sql] def aggResultToSparkColumnarBatch( footer: ParquetMetadata, parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], @@ -379,6 +396,14 @@ object ParquetUtils { new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) } + /** + * Calculate the pushed down Aggregates (Max/Min/Count) result using the statistics + * information from parquet footer file. 
+ * + * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any]. + * The first element is the PrimitiveTypeName of the Aggregate column, + * and the second element is the aggregated value. + */ private[sql] def getPushedDownAggResult( footer: ParquetMetadata, dataSchema: StructType, @@ -390,59 +415,68 @@ object ParquetUtils { val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] val valuesBuilder = ArrayBuilder.make[Any] - blocks.forEach { block => - val columns = block.getColumns() - for (i <- 0 until aggregation.aggregateExpressions.size) { - var index = 0 + for (i <- 0 until aggregation.aggregateExpressions.size) { + var value: Any = None + var rowCount = 0L + var isCount = false + var index = 0 + blocks.forEach { block => + val blockMetaData = block.getColumns() aggregation.aggregateExpressions(i) match { case Max(col, _) => index = dataSchema.fieldNames.toList.indexOf(col) - valuesBuilder += getPushedDownMaxMin(footer, columns, index, true) - typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + val currentMax = getCurrentBlockMaxOrMin(footer, blockMetaData, index, true) + if (currentMax != None && + (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { + value = currentMax + } + case Min(col, _) => index = dataSchema.fieldNames.toList.indexOf(col) - valuesBuilder += getPushedDownMaxMin(footer, columns, index, false) - typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + val currentMin = getCurrentBlockMaxOrMin(footer, blockMetaData, index, false) + if (currentMin != None && + (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { + value = currentMin + } + case Count(col, _, _) => index = dataSchema.fieldNames.toList.indexOf(col) - var rowCount = getRowCountFromParquetMetadata(footer) - if (!col.equals("1")) { // count(*) - rowCount -= getNumNulls(footer, columns, index) + rowCount = getRowCountFromParquetMetadata(footer) + if (!col.equals("1")) { // "1" is for count(*) + rowCount -= getNumNulls(footer, blockMetaData, index) } - valuesBuilder += rowCount - typesBuilder += PrimitiveType.PrimitiveTypeName.INT96 + isCount = true + case _ => } } + if (isCount) { + valuesBuilder += rowCount + typesBuilder += PrimitiveType.PrimitiveTypeName.INT96 + } else { + valuesBuilder += value + typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + } } (typesBuilder.result(), valuesBuilder.result()) } - private def getPushedDownMaxMin( + /** + * get the Max or Min value for ith column in the current block + * + * @return the Max or Min value + */ + private def getCurrentBlockMaxOrMin( footer: ParquetMetadata, columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int, - isMax: Boolean) = { + isMax: Boolean): Any = { val parquetType = footer.getFileMetaData.getSchema.getType(i) if (!parquetType.isPrimitive) { throw new IllegalArgumentException("Unsupported type : " + parquetType.toString) } - var value: Any = None val statistics = columnChunkMetaData.get(i).getStatistics() - if (isMax) { - val currentMax = statistics.genericGetMax() - if (currentMax != None && - (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { - value = currentMax - } - } else { - val currentMin = statistics.genericGetMin() - if (currentMin != None && - (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { - value = currentMin - } - } - value + if (isMax) statistics.genericGetMax() else statistics.genericGetMin() } 
private def getRowCountFromParquetMetadata(footer: ParquetMetadata): Long = { @@ -459,14 +493,12 @@ object ParquetUtils { i: Int): Long = { val parquetType = footer.getFileMetaData.getSchema.getType(i) if (!parquetType.isPrimitive) { - throw new IllegalArgumentException("Unsupported type : " + parquetType.toString) + throw new IllegalArgumentException("Unsupported type: " + parquetType.toString) } var numNulls: Long = 0; val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.isNumNullsSet()) { - throw new UnsupportedOperationException("Number of nulls not set for parquet file." + - " Set session property hive.pushdown_partial_aggregations_into_scan=false and execute" + - " query again"); + throw new UnsupportedOperationException("Number of nulls not set for parquet file."); } numNulls += statistics.getNumNulls(); numNulls diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 3e8885cbfa9df..208ebc1198ec9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.PushableColumn import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.sources.Aggregation @@ -89,7 +90,8 @@ object PushDownUtils extends PredicateHelper { scanBuilder match { case r: SupportsPushDownAggregates => - val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate) + val translatedAggregates = aggregates.map(DataSourceStrategy + .translateAggregate(_, PushableColumn(false))) val translatedGroupBys = groupBy.map(columnAsString) if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 62a3a4ddc6167..5898017bfce98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -195,7 +195,6 @@ case class ParquetPartitionReaderFactory( } override def get(): ColumnarBatch = { - val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) val footer = ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) From 60b495db61c2367ad7c57cc8c6d220da5f11441a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 4 Apr 2021 23:17:47 -0700 Subject: [PATCH 08/30] minor --- .../spark/sql/execution/datasources/parquet/ParquetUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index ea1d272ba95ad..dc853615ea5fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -441,7 +441,7 @@ object ParquetUtils { case Count(col, _, _) => index = dataSchema.fieldNames.toList.indexOf(col) - rowCount = getRowCountFromParquetMetadata(footer) + rowCount += getRowCountFromParquetMetadata(footer) if (!col.equals("1")) { // "1" is for count(*) rowCount -= getNumNulls(footer, blockMetaData, index) } From d954727db2da909d463f2c08e1456e5c7cfb7059 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Apr 2021 17:01:08 -0700 Subject: [PATCH 09/30] throw Exception if statics is not available in parquet + code clean up --- .../datasources/DataSourceStrategy.scala | 4 +- .../datasources/parquet/ParquetUtils.scala | 274 +++++++----------- .../datasources/v2/PushDownUtils.scala | 10 +- .../v2/V2ScanRelationPushDown.scala | 248 +++++++--------- .../ParquetPartitionReaderFactory.scala | 108 +++---- 5 files changed, 258 insertions(+), 386 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3fbca206874b2..c72f64e0b3d35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -685,11 +685,11 @@ object DataSourceStrategy Some(Min(name, min.dataType)) case max@aggregate.Max(pushableColumn(name)) => Some(Max(name, max.dataType)) - case count@aggregate.Count(pushableColumn(name)) => + case count: aggregate.Count => val columnName = count.children.head match { // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table case Literal(_, _) => "1" - case _ => name + case pushableColumn(name) => name } Some(Count(columnName, count.dataType, aggregates.isDistinct)) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index dc853615ea5fc..5af1e564070cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -20,7 +20,6 @@ import java.math.{BigDecimal, BigInteger} import java.time.{ZoneId, ZoneOffset} import java.util -import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuilder import scala.language.existentials @@ -41,7 +40,6 @@ import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, Deci import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.types.UTF8String - object ParquetUtils { def inferSchema( sparkSession: SparkSession, @@ -149,9 +147,9 @@ object ParquetUtils { } /** - * When the Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to + * When the partial Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to * createRowBaseReader to read data from parquet and aggregate at spark layer. 
Instead we want - * to calculate the Aggregates (Max/Min/Count) result using the statistics information + * to calculate the partial Aggregates (Max/Min/Count) result using the statistics information * from parquet footer file, and then construct an InternalRow from these Aggregate results. * * @return Aggregate results in the format of InternalRow @@ -174,106 +172,74 @@ object ParquetUtils { int96RebaseModeInRead) parquetTypes.zipWithIndex.map { case (PrimitiveType.PrimitiveTypeName.INT32, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - dataSchema.fields(i).dataType match { - case b: ByteType => - mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) - case s: ShortType => - mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) - case int: IntegerType => - mutableRow.setInt(i, values(i).asInstanceOf[Integer]) - case d: DateType => - val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Parquet") - mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) - case d: DecimalType => - val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for INT32") - } + dataSchema.fields(i).dataType match { + case b: ByteType => + mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) + case s: ShortType => + mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) + case int: IntegerType => + mutableRow.setInt(i, values(i).asInstanceOf[Integer]) + case d: DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for INT32") } case (PrimitiveType.PrimitiveTypeName.INT64, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - dataSchema.fields(i).dataType match { - case long: LongType => - mutableRow.setLong(i, values(i).asInstanceOf[Long]) - case d: DecimalType => - val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for INT64") - } + dataSchema.fields(i).dataType match { + case long: LongType => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for INT64") } case (PrimitiveType.PrimitiveTypeName.INT96, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - dataSchema.fields(i).dataType match { - case l: LongType => - mutableRow.setLong(i, values(i).asInstanceOf[Long]) - case d: TimestampType => - val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - int96RebaseMode, "Parquet INT96") - val julianMicros = - ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) - val gregorianMicros = int96RebaseFunc(julianMicros) - val adjTime = - convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) - .getOrElse(gregorianMicros) - mutableRow.setLong(i, adjTime) - case _ => throw new 
IllegalArgumentException("Unexpected type for INT96") - } + dataSchema.fields(i).dataType match { + case l: LongType => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case d: TimestampType => + val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( + int96RebaseMode, "Parquet INT96") + val julianMicros = + ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = + convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + mutableRow.setLong(i, adjTime) + case _ => throw new IllegalArgumentException("Unexpected type for INT96") } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - mutableRow.setFloat(i, values(i).asInstanceOf[Float]) - } + mutableRow.setFloat(i, values(i).asInstanceOf[Float]) case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - mutableRow.setDouble(i, values(i).asInstanceOf[Double]) - } + mutableRow.setDouble(i, values(i).asInstanceOf[Double]) case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) - } + mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) case (PrimitiveType.PrimitiveTypeName.BINARY, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - val bytes = values(i).asInstanceOf[Binary].getBytes - dataSchema.fields(i).dataType match { - case s: StringType => - mutableRow.update(i, UTF8String.fromBytes(bytes)) - case b: BinaryType => - mutableRow.update(i, bytes) - case d: DecimalType => - val decimal = - Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for Binary") - } + val bytes = values(i).asInstanceOf[Binary].getBytes + dataSchema.fields(i).dataType match { + case s: StringType => + mutableRow.update(i, UTF8String.fromBytes(bytes)) + case b: BinaryType => + mutableRow.update(i, bytes) + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for Binary") } case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => - if (values(i) == null) { - mutableRow.setNullAt(i) - } else { - val bytes = values(i).asInstanceOf[Binary].getBytes - dataSchema.fields(i).dataType match { - case d: DecimalType => - val decimal = - Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for FIXED_LEN_BYTE_ARRAY") - } + val bytes = values(i).asInstanceOf[Binary].getBytes + dataSchema.fields(i).dataType match { + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new IllegalArgumentException("Unexpected type for FIXED_LEN_BYTE_ARRAY") } case _ => throw new IllegalArgumentException("Unexpected parquet type name") @@ -315,81 +281,49 @@ object ParquetUtils { parquetTypes.zipWithIndex.map { case (PrimitiveType.PrimitiveTypeName.INT32, i) => - if (values(i) == null) { - 
columnVectors(i).appendNull() - } else { - dataSchema.fields(i).dataType match { - case b: ByteType => - columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) - case s: ShortType => - columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) - case int: IntegerType => - columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) - case d: DateType => - val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Parquet") - columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) - case _ => throw new IllegalArgumentException("Unexpected type for INT32") - } + dataSchema.fields(i).dataType match { + case b: ByteType => + columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) + case s: ShortType => + columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) + case int: IntegerType => + columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) + case d: DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) + case _ => throw new IllegalArgumentException("Unexpected type for INT32") } case (PrimitiveType.PrimitiveTypeName.INT64, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - columnVectors(i).appendLong(values(i).asInstanceOf[Long]) - } + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) case (PrimitiveType.PrimitiveTypeName.INT96, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - dataSchema.fields(i).dataType match { - case l: LongType => - columnVectors(i).appendLong(values(i).asInstanceOf[Long]) - case d: TimestampType => - val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - int96RebaseMode, "Parquet INT96") - val julianMicros = - ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) - val gregorianMicros = int96RebaseFunc(julianMicros) - val adjTime = - convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) - .getOrElse(gregorianMicros) - columnVectors(i).appendLong(adjTime) - case _ => throw new IllegalArgumentException("Unexpected type for INT96") - } + dataSchema.fields(i).dataType match { + case l: LongType => + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + case d: TimestampType => + val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( + int96RebaseMode, "Parquet INT96") + val julianMicros = + ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) + val gregorianMicros = int96RebaseFunc(julianMicros) + val adjTime = + convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) + .getOrElse(gregorianMicros) + columnVectors(i).appendLong(adjTime) + case _ => throw new IllegalArgumentException("Unexpected type for INT96") } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) - } + columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) - } + columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) case (PrimitiveType.PrimitiveTypeName.BINARY, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - val bytes = 
values(i).asInstanceOf[Binary].getBytes - columnVectors(i).putByteArray(0, bytes, 0, bytes.length) - } + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - val bytes = values(i).asInstanceOf[Binary].getBytes - columnVectors(i).putByteArray(0, bytes, 0, bytes.length) - } + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => - if (values(i) == null) { - columnVectors(i).appendNull() - } else { - columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) - } + columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) case _ => throw new IllegalArgumentException("Unexpected parquet type name") } @@ -441,7 +375,7 @@ object ParquetUtils { case Count(col, _, _) => index = dataSchema.fieldNames.toList.indexOf(col) - rowCount += getRowCountFromParquetMetadata(footer) + rowCount += block.getRowCount if (!col.equals("1")) { // "1" is for count(*) rowCount -= getNumNulls(footer, blockMetaData, index) } @@ -476,15 +410,12 @@ object ParquetUtils { throw new IllegalArgumentException("Unsupported type : " + parquetType.toString) } val statistics = columnChunkMetaData.get(i).getStatistics() - if (isMax) statistics.genericGetMax() else statistics.genericGetMin() - } - - private def getRowCountFromParquetMetadata(footer: ParquetMetadata): Long = { - var rowCount: Long = 0 - for (blockMetaData <- footer.getBlocks.asScala) { - rowCount += blockMetaData.getRowCount + if (!statistics.hasNonNullValue) { + throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" + + " PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + } else { + if (isMax) statistics.genericGetMax() else statistics.genericGetMin() } - rowCount } private def getNumNulls( @@ -495,12 +426,11 @@ object ParquetUtils { if (!parquetType.isPrimitive) { throw new IllegalArgumentException("Unsupported type: " + parquetType.toString) } - var numNulls: Long = 0; val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.isNumNullsSet()) { - throw new UnsupportedOperationException("Number of nulls not set for parquet file."); + throw new UnsupportedOperationException("Number of nulls not set for parquet file." 
+ + " Set SQLConf PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") } - numNulls += statistics.getNumNulls(); - numNulls + statistics.getNumNulls(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 208ebc1198ec9..0745deb82c998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -83,21 +83,15 @@ object PushDownUtils extends PredicateHelper { aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Aggregation = { - def columnAsString(e: Expression): String = e match { - case AttributeReference(name, _, _, _) => name - case _ => "" - } - scanBuilder match { case r: SupportsPushDownAggregates => val translatedAggregates = aggregates.map(DataSourceStrategy .translateAggregate(_, PushableColumn(false))) - val translatedGroupBys = groupBy.map(columnAsString) - if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) { + if (translatedAggregates.exists(_.isEmpty)) { Aggregation.empty } else { - r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys)) + r.pushAggregation(Aggregation(translatedAggregates.flatten, Seq.empty)) r.pushedAggregation } case _ => Aggregation.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 00c398b24a70d..aeff0ae4c0015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.mutable.ArrayBuilder - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.connector.read.{Scan, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.sources.{AggregateFunc, Aggregation} @@ -34,120 +33,120 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case aggNode@Aggregate(groupingExpressions, resultExpressions, child) => + // Pattern matching for partial aggregate push down + // For parquet, footer only has statistics information for max/min/count. + // It doesn't handle max/min/count associated with filter or group by. + // ORC is similar. If JDBC partial aggregate push down is added later, + // these condition checks need to be changed. 
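To make the guards in the case clause below concrete: only an Aggregate sitting directly on a ScanOperation with no grouping expressions and no remaining filters qualifies, so a WHERE clause or GROUP BY falls back to the original plan. A rough, hedged sketch of the distinction (the table `t` and its columns are hypothetical, and it assumes this patch is applied so the new SQLConf entry exists):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object AggPushDownEligibilitySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key, "true")
      .getOrCreate()

    // `t` is assumed to be an existing Parquet-backed table in the catalog.
    // Eligible: aggregate directly over the relation, no filter, no GROUP BY.
    spark.sql("SELECT min(c1), max(c1), count(c1) FROM t").explain()
    // Not eligible: a post-scan filter remains, so the filters.isEmpty guard fails.
    spark.sql("SELECT max(c1) FROM t WHERE c2 > 0").explain()
    // Not eligible: grouping expressions present, so groupingExpressions.isEmpty fails.
    spark.sql("SELECT max(c1) FROM t GROUP BY c2").explain()

    spark.stop()
  }
}
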
+ case aggNode@Aggregate(groupingExpressions, resultExpressions, child) + if (groupingExpressions.isEmpty) => child match { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => + case ScanOperation(project, filters, relation: DataSourceV2Relation) + if (filters.isEmpty) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - - val aliasMap = getAliasMap(project) var aggregates = resultExpressions.flatMap { expr => expr.collect { case agg: AggregateExpression => - replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression] + replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] } } aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) - .asInstanceOf[Seq[AggregateExpression]] + .asInstanceOf[Seq[AggregateExpression]] - val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ expr => - expr.collect { - case e: Expression => replaceAlias(e, aliasMap) - } - } - val normalizedGroupingExpressions = - DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, relation.output) - - var newFilters = filters - aggregates.foreach(agg => - if (agg.filter.nonEmpty) { - // handle agg filter the same way as other filters - newFilters = newFilters :+ agg.filter.get - } - ) - - val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, newFilters, relation) - if (postScanFilters.nonEmpty) { + val aggregation = PushDownUtils + .pushAggregates(scanBuilder, aggregates, groupingExpressions) + if (aggregation.aggregateExpressions.isEmpty) { aggNode // return original plan node - } else { // only push down aggregate if all the filers can be push down - val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates, - normalizedGroupingExpressions) + } else { + // use the aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // Use min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + val output = aggregates.map { + case agg: AggregateExpression => + AttributeReference(toPrettySQL(agg), agg.dataType)() + } - val (scan, output, normalizedProjects) = - processFilterAndColumn(scanBuilder, project, postScanFilters, relation) + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. 
+ val scan = scanBuilder.build() logInfo( s""" |Pushing operators to ${relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} - |Post-Scan Filters: ${postScanFilters.mkString(",")} |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} - |Pushed Groupby: ${aggregation.groupByColumns.mkString(", ")} |Output: ${output.mkString(", ")} - """.stripMargin) - + """.stripMargin) val wrappedScan = scan match { case v1: V1Scan => - val translated = newFilters.flatMap(DataSourceStrategy.translateFilter(_, true)) - V1ScanWrapper(v1, translated, pushedFilters, aggregation) + V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter], aggregation) case _ => scan } - if (aggregation.aggregateExpressions.isEmpty) { - aggNode // return original plan node - } else { - // build the aggregate expressions + groupby expressions - val aggOutputBuilder = ArrayBuilder.make[AttributeReference] - for (i <- 0 until aggregates.length) { - aggOutputBuilder += AttributeReference( - aggregation.aggregateExpressions(i).toString, aggregates(i).dataType)() - } - groupingExpressions.foreach{ - case a@AttributeReference(_, _, _, _) => aggOutputBuilder += a - case _ => - } - val aggOutput = aggOutputBuilder.result - - val r = buildLogicalPlan(aggOutput, relation, wrappedScan, aggOutput, - normalizedProjects, postScanFilters) - val plan = Aggregate(groupingExpressions, resultExpressions, r) - - var i = 0 - // scalastyle:off line.size.limit - // change the original optimized logical plan to reflect the pushed down aggregate - // e.g. sql("select max(id), min(id) FROM h2.test.people") - // the the original optimized logical plan is - // == Optimized Logical Plan == - // Aggregate [max(ID#35) AS max(ID)#38, min(ID#35) AS min(ID)#39] - // +- RelationV2[ID#35] test.people - // We want to change it to the following - // == Optimized Logical Plan == - // Aggregate [max(Max(ID,IntegerType)#298) AS max(ID)#293, min(Min(ID,IntegerType)#299) AS min(ID)#294] - // +- RelationV2[Max(ID,IntegerType)#298, Min(ID,IntegerType)#299] test.people - // scalastyle:on line.size.limit - plan.transformExpressions { - case agg: AggregateExpression => - i += 1 - val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { - case max: aggregate.Max => aggregate.Max(aggOutput(i - 1)) - case min: aggregate.Min => aggregate.Min(aggOutput(i - 1)) - case sum: aggregate.Sum => aggregate.Sum(aggOutput(i - 1)) - case avg: aggregate.Average => aggregate.Average(aggOutput(i - 1)) - case count: aggregate.Count => aggregate.PushDownCount(aggOutput(i - 1), true) - case _ => agg.aggregateFunction - } - agg.copy(aggregateFunction = aggFunction, filter = None) - } + val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // The original logical plan is + // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9] parquet ... + // + // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // we have the following + // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] + // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... 
+ // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { + case max: aggregate.Max => aggregate.Max(output(i - 1)) + case min: aggregate.Min => aggregate.Min(output(i - 1)) + case sum: aggregate.Sum => aggregate.Sum(output(i - 1)) + case avg: aggregate.Average => aggregate.Average(output(i - 1)) + case count: aggregate.Count => aggregate.PushDownCount(output(i - 1), true) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction, filter = None) } } case _ => aggNode // return original plan node } + case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - val (pushedFilters, postScanFilters) = pushDownFilter (scanBuilder, filters, relation) - val (scan, output, normalizedProjects) = - processFilterAndColumn(scanBuilder, project, postScanFilters, relation) + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) + val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) + + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( + scanBuilder, normalizedFiltersWithoutSubquery) + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + + val normalizedProjects = DataSourceStrategy + .normalizeExprs(project, relation.output) + .asInstanceOf[Seq[NamedExpression]] + val (scan, output) = PushDownUtils.pruneColumns( + scanBuilder, relation, normalizedProjects, postScanFilters) logInfo( s""" |Pushing operators to ${relation.name} @@ -165,66 +164,27 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { case _ => scan } - buildLogicalPlan(project, relation, wrappedScan, output, normalizedProjects, postScanFilters) - } + val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - private def pushDownFilter( - scanBuilder: ScanBuilder, - filters: Seq[Expression], - relation: DataSourceV2Relation): (Seq[sources.Filter], Seq[Expression]) = { - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) - val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = - normalizedFilters.partition(SubqueryExpression.hasSubquery) - - // `pushedFilters` will be pushed down and evaluated in the underlying data sources. - // `postScanFilters` need to be evaluated after the scan. - // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
- val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) - val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery - (pushedFilters, postScanFilters) - } + val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionFunc = (expr: Expression) => expr transformDown { + case projectionOverSchema(newExpr) => newExpr + } - private def processFilterAndColumn( - scanBuilder: ScanBuilder, - project: Seq[NamedExpression], - postScanFilters: Seq[Expression], - relation: DataSourceV2Relation): - (Scan, Seq[AttributeReference], Seq[NamedExpression]) = { - val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) - .asInstanceOf[Seq[NamedExpression]] - val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) - (scan, output, normalizedProjects) - } + val filterCondition = postScanFilters.reduceLeftOption(And) + val newFilterCondition = filterCondition.map(projectionFunc) + val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) + + val withProjection = if (withFilter.output != project) { + val newProjects = normalizedProjects + .map(projectionFunc) + .asInstanceOf[Seq[NamedExpression]] + Project(newProjects, withFilter) + } else { + withFilter + } - private def buildLogicalPlan( - project: Seq[NamedExpression], - relation: DataSourceV2Relation, - wrappedScan: Scan, - output: Seq[AttributeReference], - normalizedProjects: Seq[NamedExpression], - postScanFilters: Seq[Expression]): LogicalPlan = { - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - val projectionOverSchema = ProjectionOverSchema(output.toStructType) - val projectionFunc = (expr: Expression) => expr transformDown { - case projectionOverSchema(newExpr) => newExpr - } - - val filterCondition = postScanFilters.reduceLeftOption(And) - val newFilterCondition = filterCondition.map(projectionFunc) - val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) - - val withProjection = if (withFilter.output != project) { - val newProjects = normalizedProjects - .map(projectionFunc) - .asInstanceOf[Seq[NamedExpression]] - Project(newProjects, withFilter) - } else { - withFilter - } - withProjection + withProjection } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 5898017bfce98..bc1ac8e985835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -27,6 +27,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.metadata.ParquetMetadata import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast @@ -66,7 +67,6 @@ case class ParquetPartitionReaderFactory( aggregation: Aggregation, parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private 
val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val aggSchema = buildAggSchema private val newReadDataSchema = if (aggregation.aggregateExpressions.isEmpty) { readDataSchema } else { @@ -87,32 +87,50 @@ case class ParquetPartitionReaderFactory( private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private lazy val aggSchema = { + var schema = new StructType() + aggregation.aggregateExpressions.map { + case Max(col, _) => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) + schema = schema.add(field.copy("max(" + field.name + ")")) + case Min(col, _) => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) + schema = schema.add(field.copy("min(" + field.name + ")")) + case Count(col, _, _) => + if (col.equals("1")) { + schema = schema.add(new StructField("count(*)", LongType)) + } else { + schema = schema.add(new StructField("count(" + col + ")", LongType)) + } + case _ => + } + schema + } - private def buildAggSchema: StructType = { - var aggSchema = new StructType() - for (i <- 0 until aggregation.aggregateExpressions.size) { - var index = 0 - aggregation.aggregateExpressions(i) match { - case Max(col, _) => - index = dataSchema.fieldNames.toList.indexOf(col) - val field = dataSchema.fields(index) - aggSchema = aggSchema.add(field.copy("max(" + field.name + ")")) - case Min(col, _) => - index = dataSchema.fieldNames.toList.indexOf(col) - val field = dataSchema.fields(index) - aggSchema = aggSchema.add(field.copy("min(" + field.name + ")")) - case Count(col, _, _) => - if (col.equals("1")) { - aggSchema = aggSchema.add(new StructField("count(*)", LongType)) - } else { - aggSchema = aggSchema.add(new StructField("count(" + col + ")", LongType)) - } - case _ => - } + private def getFooter(file: PartitionedFile): ParquetMetadata = { + val conf = broadcastedConf.value.value + + val filePath = new Path(new URI(file.filePath)) + + if (aggregation.aggregateExpressions.isEmpty) { + ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) + } else { + ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) } - aggSchema } + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
+ private def isCreatedByParquetMr(file: PartitionedFile): Boolean = + getFooter(file).getFileMetaData.getCreatedBy().startsWith("parquet-mr") + + private def convertTz(isCreatedByParquetMr: Boolean): Option[ZoneId] = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils + .getZoneId(broadcastedConf.value.value.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -146,21 +164,11 @@ case class ParquetPartitionReaderFactory( } override def get(): InternalRow = { - val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) - val footer = ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) - def isCreatedByParquetMr: Boolean = - footer.getFileMetaData.getCreatedBy().startsWith("parquet-mr") - val convertTz = - if (timestampConversion && !isCreatedByParquetMr) { - Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) - } else { - None - } + val footer = getFooter(file) val (parquetTypes, values) = ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) ParquetUtils.aggResultToSparkInternalRows(footer, parquetTypes, values, aggSchema, - datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz) + datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz(isCreatedByParquetMr(file))) } override def close(): Unit = return @@ -195,21 +203,12 @@ case class ParquetPartitionReaderFactory( } override def get(): ColumnarBatch = { - val conf = broadcastedConf.value.value - val filePath = new Path(new URI(file.filePath)) - val footer = ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) - def isCreatedByParquetMr: Boolean = - footer.getFileMetaData.getCreatedBy().startsWith("parquet-mr") - val convertTz = - if (timestampConversion && !isCreatedByParquetMr) { - Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) - } else { - None - } + val footer = getFooter(file) val (parquetTypes, values) = ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) ParquetUtils.aggResultToSparkColumnarBatch(footer, parquetTypes, values, aggSchema, - enableOffHeapColumnVector, datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz) + enableOffHeapColumnVector, datetimeRebaseModeInRead, int96RebaseModeInRead, + convertTz(isCreatedByParquetMr(file))) } override def close(): Unit = return @@ -230,8 +229,7 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) - lazy val footerFileMetaData = - ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData + lazy val footerFileMetaData = getFooter(file).getFileMetaData // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema @@ -250,16 +248,6 @@ case class ParquetPartitionReaderFactory( // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. - // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
- def isCreatedByParquetMr: Boolean = - footerFileMetaData.getCreatedBy().startsWith("parquet-mr") - - val convertTz = - if (timestampConversion && !isCreatedByParquetMr) { - Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) - } else { - None - } val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) @@ -280,7 +268,7 @@ case class ParquetPartitionReaderFactory( file.partitionValues, hadoopAttemptContext, pushed, - convertTz, + convertTz(isCreatedByParquetMr(file)), datetimeRebaseMode, int96RebaseMode) reader.initialize(split, hadoopAttemptContext) From a505e7f79b233ea47e22399ef16969684b8dfe57 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Apr 2021 20:10:01 -0700 Subject: [PATCH 10/30] fix build failure --- .../sql/catalyst/expressions/aggregate/PushDownCount.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala index 79de67063c8b4..610b780d5d6fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala @@ -34,6 +34,10 @@ case class PushDownCount(children: Seq[Expression], pushdown: Boolean) extends C ) } } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) + : PushDownCount = + copy(children = newChildren) } object PushDownCount { From d118eaf183814d735fe8499b1d82af20f28d8c9e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 26 Apr 2021 10:04:07 -0700 Subject: [PATCH 11/30] rewrite count to sum for pushed down count --- .../expressions/aggregate/CountBase.scala | 73 ----- .../expressions/aggregate/PushDownCount.scala | 46 ---- .../datasources/parquet/ParquetUtils.scala | 30 +-- .../v2/V2ScanRelationPushDown.scala | 10 +- .../parquet/ParquetQuerySuite.scala | 253 ++++++++++++++---- 5 files changed, 223 insertions(+), 189 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala deleted file mode 100644 index 4a30611727b2e..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ - -abstract class CountBase(children: Seq[Expression]) extends DeclarativeAggregate { - - override def nullable: Boolean = false - - // Return data type. - override def dataType: DataType = LongType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) { - TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " + - s"If you have to call the function $prettyName without arguments, set the legacy " + - s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - - protected lazy val count = AttributeReference("count", LongType, nullable = false)() - - override lazy val aggBufferAttributes = count :: Nil - - override lazy val initialValues = Seq( - /* count = */ Literal(0L) - ) - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) - - private[sql] var pushDown: Boolean = false - - override lazy val updateExpressions = { - val nullableChildren = children.filter(_.nullable) - if (nullableChildren.isEmpty) { - Seq( - /* count = */ count + 1L - ) - } else { - Seq( - /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) - ) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala deleted file mode 100644 index 610b780d5d6fe..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ - -case class PushDownCount(children: Seq[Expression], pushdown: Boolean) extends CountBase(children) { - - override lazy val updateExpressions = { - val nullableChildren = children.filter(_.nullable) - if (nullableChildren.isEmpty) { - Seq( - /* count = */ count + children.head - ) - } else { - Seq( - /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + children.head) - ) - } - } - - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) - : PushDownCount = - copy(children = newChildren) -} - -object PushDownCount { - def apply(child: Expression, pushdown: Boolean): PushDownCount = - PushDownCount(child :: Nil, pushdown) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 5af1e564070cc..73a514432f06a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -173,13 +173,13 @@ object ParquetUtils { parquetTypes.zipWithIndex.map { case (PrimitiveType.PrimitiveTypeName.INT32, i) => dataSchema.fields(i).dataType match { - case b: ByteType => + case ByteType => mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) - case s: ShortType => + case ShortType => mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) - case int: IntegerType => + case IntegerType => mutableRow.setInt(i, values(i).asInstanceOf[Integer]) - case d: DateType => + case DateType => val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( datetimeRebaseMode, "Parquet") mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) @@ -190,7 +190,7 @@ object ParquetUtils { } case (PrimitiveType.PrimitiveTypeName.INT64, i) => dataSchema.fields(i).dataType match { - case long: LongType => + case LongType => mutableRow.setLong(i, values(i).asInstanceOf[Long]) case d: DecimalType => val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) @@ -199,9 +199,9 @@ object ParquetUtils { } case (PrimitiveType.PrimitiveTypeName.INT96, i) => dataSchema.fields(i).dataType match { - case l: LongType => + case LongType => mutableRow.setLong(i, values(i).asInstanceOf[Long]) - case d: TimestampType => + case TimestampType => val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( int96RebaseMode, "Parquet INT96") val julianMicros = @@ -222,9 +222,9 @@ object ParquetUtils { case (PrimitiveType.PrimitiveTypeName.BINARY, i) => val bytes = values(i).asInstanceOf[Binary].getBytes dataSchema.fields(i).dataType match { - case s: StringType => + case StringType => mutableRow.update(i, UTF8String.fromBytes(bytes)) - case b: BinaryType => + case BinaryType => mutableRow.update(i, bytes) case d: DecimalType => val decimal = @@ -282,13 +282,13 @@ object ParquetUtils { parquetTypes.zipWithIndex.map { case (PrimitiveType.PrimitiveTypeName.INT32, i) => dataSchema.fields(i).dataType match { - case b: ByteType => + case ByteType => columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) - case s: ShortType => + case ShortType => columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) - case int: IntegerType => + case IntegerType => 
columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) - case d: DateType => + case DateType => val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( datetimeRebaseMode, "Parquet") columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) @@ -298,9 +298,9 @@ object ParquetUtils { columnVectors(i).appendLong(values(i).asInstanceOf[Long]) case (PrimitiveType.PrimitiveTypeName.INT96, i) => dataSchema.fields(i).dataType match { - case l: LongType => + case LongType => columnVectors(i).appendLong(values(i).asInstanceOf[Long]) - case d: TimestampType => + case TimestampType => val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( int96RebaseMode, "Parquet INT96") val julianMicros = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index aeff0ae4c0015..64951e43052fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -114,11 +114,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { case agg: AggregateExpression => i += 1 val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { - case max: aggregate.Max => aggregate.Max(output(i - 1)) - case min: aggregate.Min => aggregate.Min(output(i - 1)) - case sum: aggregate.Sum => aggregate.Sum(output(i - 1)) - case avg: aggregate.Average => aggregate.Average(output(i - 1)) - case count: aggregate.Count => aggregate.PushDownCount(output(i - 1), true) + case _: aggregate.Max => aggregate.Max(output(i - 1)) + case _: aggregate.Min => aggregate.Min(output(i - 1)) + case _: aggregate.Sum => aggregate.Sum(output(i - 1)) + case _: aggregate.Average => aggregate.Average(output(i - 1)) + case _: aggregate.Count => aggregate.Sum(output(i - 1)) case _ => agg.aggregateFunction } agg.copy(aggregateFunction = aggFunction, filter = None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 373af907a9a78..2a67162c86ca4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import java.sql.{Date, Timestamp} import java.util.concurrent.TimeUnit import org.apache.hadoop.fs.{FileSystem, Path} @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol} import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement} -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -40,7 +41,8 @@ import 
org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. */ -abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession { +abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession + with ExplainSuiteHelper { import testImplicits._ test("simple select queries") { @@ -50,54 +52,6 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } - test("test aggregate pushdown") { - spark.conf.set(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key, "true") - val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), - (9, "mno", 7), (2, null, 6)) - spark.createDataFrame(data).toDF("c1", "c2", "c3").createOrReplaceTempView("tmp") - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + - " count(*), count(_1), count(_2), count(_3) FROM t") - // selectAgg.explain(true) - // scalastyle:off line.size.limit - // == Parsed Logical Plan == - // 'Project [unresolvedalias('min('_3), None), unresolvedalias('min('_3), None), unresolvedalias('max('_3), None), unresolvedalias('min('_1), None), unresolvedalias('max('_1), None), unresolvedalias('max('_1), None), unresolvedalias('count(1), None), unresolvedalias('count('_1), None), unresolvedalias('count('_2), None), unresolvedalias('count('_3), None)] - // +- 'UnresolvedRelation [t], [], false - // - // == Analyzed Logical Plan == - // min(_3): int, min(_3): int, max(_3): int, min(_1): int, max(_1): int, max(_1): int, count(1): bigint, count(_1): bigint, count(_2): bigint, count(_3): bigint - // Aggregate [min(_3#23) AS min(_3)#43, min(_3#23) AS min(_3)#44, max(_3#23) AS max(_3)#45, min(_1#21) AS min(_1)#46, max(_1#21) AS max(_1)#47, max(_1#21) AS max(_1)#48, count(1) AS count(1)#49L, count(_1#21) AS count(_1)#50L, count(_2#22) AS count(_2)#51L, count(_3#23) AS count(_3)#52L] - // +- SubqueryAlias t - // +- View (`t`, [_1#21,_2#22,_3#23]) - // +- RelationV2[_1#21, _2#22, _3#23] parquet file:/private/var/folders/hm/dghdj3hn791fd9bfmwnl12km0000gn/T/spark-6609ad86-7ff5-4f83-96e2-fea5f2c85646 - // - // == Optimized Logical Plan == - // Aggregate [min(Min(_3,IntegerType)#66) AS min(_3)#43, min(Min(_3,IntegerType)#67) AS min(_3)#44, max(Max(_3,IntegerType)#68) AS max(_3)#45, min(Min(_1,IntegerType)#69) AS min(_1)#46, max(Max(_1,IntegerType)#70) AS max(_1)#47, max(Max(_1,IntegerType)#71) AS max(_1)#48, count(Count(1,LongType,false)#72L) AS count(1)#49L, count(Count(_1,LongType,false)#73L) AS count(_1)#50L, count(Count(_2,LongType,false)#74L) AS count(_2)#51L, count(Count(_3,LongType,false)#75L) AS count(_3)#52L] - // +- RelationV2[Min(_3,IntegerType)#66, Min(_3,IntegerType)#67, Max(_3,IntegerType)#68, Min(_1,IntegerType)#69, Max(_1,IntegerType)#70, Max(_1,IntegerType)#71, Count(1,LongType,false)#72L, Count(_1,LongType,false)#73L, Count(_2,LongType,false)#74L, Count(_3,LongType,false)#75L] parquet file:/private/var/folders/hm/dghdj3hn791fd9bfmwnl12km0000gn/T/spark-6609ad86-7ff5-4f83-96e2-fea5f2c85646 - // - // == Physical Plan == - // AdaptiveSparkPlan isFinalPlan=false - // +- HashAggregate(keys=[], functions=[min(Min(_3,IntegerType)#66), min(Min(_3,IntegerType)#67), max(Max(_3,IntegerType)#68), min(Min(_1,IntegerType)#69), max(Max(_1,IntegerType)#70), max(Max(_1,IntegerType)#71), count(Count(1,LongType,false)#72L), count(Count(_1,LongType,false)#73L), count(Count(_2,LongType,false)#74L), count(Count(_3,LongType,false)#75L)], 
output=[min(_3)#43, min(_3)#44, max(_3)#45, min(_1)#46, max(_1)#47, max(_1)#48, count(1)#49L, count(_1)#50L, count(_2)#51L, count(_3)#52L]) - // +- HashAggregate(keys=[], functions=[partial_min(Min(_3,IntegerType)#66), partial_min(Min(_3,IntegerType)#67), partial_max(Max(_3,IntegerType)#68), partial_min(Min(_1,IntegerType)#69), partial_max(Max(_1,IntegerType)#70), partial_max(Max(_1,IntegerType)#71), partial_count(Count(1,LongType,false)#72L), partial_count(Count(_1,LongType,false)#73L), partial_count(Count(_2,LongType,false)#74L), partial_count(Count(_3,LongType,false)#75L)], output=[min#86, min#87, max#88, min#89, max#90, max#91, count#92L, count#93L, count#94L, count#95L]) - // +- Project [Min(_3,IntegerType)#66, Min(_3,IntegerType)#67, Max(_3,IntegerType)#68, Min(_1,IntegerType)#69, Max(_1,IntegerType)#70, Max(_1,IntegerType)#71, Count(1,LongType,false)#72L, Count(_1,LongType,false)#73L, Count(_2,LongType,false)#74L, Count(_3,LongType,false)#75L] - // +- BatchScan[Min(_3,IntegerType)#66, Min(_3,IntegerType)#67, Max(_3,IntegerType)#68, Min(_1,IntegerType)#69, Max(_1,IntegerType)#70, Max(_1,IntegerType)#71, Count(1,LongType,false)#72L, Count(_1,LongType,false)#73L, Count(_2,LongType,false)#74L, Count(_3,LongType,false)#75L] ParquetScan DataFilters: [], Format: parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/hm/dghdj3hn791fd9bfmwnl12km0000gn/T/spark-66..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<_1:int,_2:string,_3:int>, PushedFilters: [] - // scalastyle:on line.size.limit - - // selectAgg.show() - // +-------+-------+-------+-------+-------+-------+--------+---------+---------+---------+ - // |min(_3)|min(_3)|max(_3)|min(_1)|max(_1)|max(_1)|count(1)|count(_1)|count(_2)|count(_3)| - // +-------+-------+-------+-------+-------+-------+--------+---------+---------+---------+ - // | 2| 2| 19| -2| 9| 9| 6| 6| 4| 6| - // +-------+-------+-------+-------+-------+-------+--------+---------+---------+---------+ - - checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) - } - spark.sessionState.catalog.dropTable( - TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) - spark.conf.unset(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key) - } - test("appending") { val data = (0 until 10).map(i => (i, i.toString)) spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") @@ -949,6 +903,205 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } } + + test("test aggregate push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + + // This is not pushed down since aggregates have arithmetic operation + val selectAgg1 = sql("SELECT min(_3 + _1), max(_3 + _1) FROM t") + checkAnswer(selectAgg1, Seq(Row(0, 19))) + + // sum is not pushed down + val selectAgg2 = sql("SELECT sum(_3) FROM t") + checkAnswer(selectAgg2, Seq(Row(40))) + + val selectAgg3 = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + " count(*), count(_1), count(_2), count(_3) FROM t") + + selectAgg3.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Min(_3,IntegerType), " + + "Min(_3,IntegerType), " + + "Max(_3,IntegerType), " + + "Min(_1,IntegerType), " + + "Max(_1,IntegerType), " + + "Max(_1,IntegerType), " + + "Count(1,LongType,false), " + + 
"Count(_1,LongType,false), " + + "Count(_2,LongType,false), " + + "Count(_3,LongType,false)]" + checkKeywordsExistsInExplain(selectAgg3, expected_plan_fragment) + } + + checkAnswer(selectAgg3, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + } + } + spark.sessionState.catalog.dropTable( + TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false) + } + + test("aggregate pushdown for different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + val schema = StructType(List(StructField("StringCol", StringType, true), + StructField("BooleanCol", BooleanType, false), + StructField("ByteCol", ByteType, false), + StructField("BinaryCol", BinaryType, false), + StructField("ShortCol", ShortType, false), + StructField("IntegerCol", IntegerType, true), + StructField("LongCol", LongType, false), + StructField("FloatCol", FloatType, false), + StructField("DoubleCol", DoubleType, false), + StructField("DecimalCol", DecimalType(25, 5), true), + StructField("DateCol", DateType, false), + StructField("TimestampCol", TimestampType, false)).toArray) + + val rdd = sparkContext.parallelize(rows) + spark.createDataFrame(rdd, schema).createOrReplaceTempView("test") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + withTempPath { file => + val df1 = spark.createDataFrame(rdd, schema) + df1.write.parquet(file.getCanonicalPath) + val df2 = spark.read.parquet(file.getCanonicalPath) + df2.createOrReplaceTempView("test") + + val testMin = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") + + testMin.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Min(StringCol,StringType), " + + "Min(BooleanCol,BooleanType), " + + "Min(ByteCol,ByteType), " + + "Min(BinaryCol,BinaryType), " + + "Min(ShortCol,ShortType), " + + "Min(IntegerCol,IntegerType), " + + "Min(LongCol,LongType), " + + "Min(FloatCol,FloatType), " + + "Min(DoubleCol,DoubleType), " + + "Min(DecimalCol,DecimalType(25,5)), " + + "Min(DateCol,DateType)]" + checkKeywordsExistsInExplain(testMin, expected_plan_fragment) + } + + checkAnswer(testMin, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date))) + + val testMax = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + 
"max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") + + testMax.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Max(StringCol,StringType), " + + "Max(BooleanCol,BooleanType)," + + " Max(ByteCol,ByteType), " + + "Max(BinaryCol,BinaryType)," + + " Max(ShortCol,ShortType), " + + "Max(IntegerCol,IntegerType)," + + " Max(LongCol,LongType), " + + "Max(FloatCol,FloatType)," + + " Max(DoubleCol,DoubleType), " + + "Max(DecimalCol,DecimalType(25,5)), " + + "Max(DateCol,DateType)]" + checkKeywordsExistsInExplain(testMax, expected_plan_fragment) + } + + checkAnswer(testMax, Seq(Row("test string", true, 16.toByte, "Spark SQL".getBytes, + 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, 12345.678, + ("2021-01-01").date))) + + val testCount = sql("SELECT count(*), count(StringCol), count(BooleanCol)," + + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + + " count(LongCol), count(FloatCol), count(DoubleCol)," + + " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") + + testCount.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Count(1,LongType,false), " + + "Count(StringCol,LongType,false), " + + "Count(BooleanCol,LongType,false), " + + "Count(ByteCol,LongType,false), " + + "Count(BinaryCol,LongType,false), " + + "Count(ShortCol,LongType,false), " + + "Count(IntegerCol,LongType,false), " + + "Count(LongCol,LongType,false), " + + "Count(FloatCol,LongType,false), " + + "Count(DoubleCol,LongType,false), " + + "Count(DecimalCol,LongType,false), " + + "Count(DateCol,LongType,false), " + + "Count(TimestampCol,LongType,false)]" + checkKeywordsExistsInExplain(testCount, expected_plan_fragment) + } + + checkAnswer(testCount, Seq(Row(3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + } + } + } + } } class ParquetV1QuerySuite extends ParquetQuerySuite { From 9af611f4ed707db3c21a907da958379b01094ffd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 26 Apr 2021 10:19:21 -0700 Subject: [PATCH 12/30] remove unnessary change in Count --- .../expressions/aggregate/Count.scala | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 472436f3884f5..dfdd828d10d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -80,25 +80,15 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def defaultResult: Option[Literal] = Option(Literal(0L)) - private[sql] var pushDown: Boolean = false - override lazy val updateExpressions = { - if (!pushDown) { - val nullableChildren = children.filter(_.nullable) - if (nullableChildren.isEmpty) { - Seq( - /* count = */ count + 1L - ) - } else { - Seq( - /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) - ) - } + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) } else { Seq( - // if count is pushed down to Data Source layer, add the count result retrieved from - // Data Source - /* count = */ count + children.head + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, 
count + 1L) ) } } From d277701eb888c2703fa49c5cd6b8246bce2c00a6 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 27 Apr 2021 18:53:37 -0700 Subject: [PATCH 13/30] add a rule for PartialAggregatePushDown --- .../apache/spark/sql/connector/read/Scan.java | 11 ++ .../read/SupportsPushDownAggregates.java | 44 ------ .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../PartialAggregatePushDown.scala | 148 ++++++++++++++++++ .../datasources/v2/PushDownUtils.scala | 30 +--- .../v2/V2ScanRelationPushDown.scala | 99 +----------- .../datasources/v2/parquet/ParquetScan.scala | 25 ++- .../v2/parquet/ParquetScanBuilder.scala | 33 +--- .../org/apache/spark/sql/FileScanSuite.scala | 5 +- 9 files changed, 193 insertions(+), 206 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index b70a656c492a8..dd09c37cada90 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.sources.Aggregation; /** * A logical representation of a data source scan. This interface is used to provide logical @@ -112,4 +113,14 @@ default CustomMetric[] supportedCustomMetrics() { CustomMetric[] NO_METRICS = {}; return NO_METRICS; } + + /** + * Pushes down Aggregation to scan. + * The Aggregation can be pushed down only if all the Aggregate Functions can + * be pushed down. + */ + default void pushAggregation(Aggregation aggregation) { + throw new UnsupportedOperationException(description() + + ": Push down Aggregation is not supported"); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java deleted file mode 100644 index 40ed146114ffe..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.read; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.Aggregation; - -/** - * A mix-in interface for {@link ScanBuilder}. 
Data sources can implement this interface to - * push down aggregates to the data source. - * - * @since 3.2.0 - */ -@Evolving -public interface SupportsPushDownAggregates extends ScanBuilder { - - /** - * Pushes down Aggregation to datasource. - * The Aggregation can be pushed down only if all the Aggregate Functions can - * be pushed down. - */ - void pushAggregation(Aggregation aggregation); - - /** - * Returns the aggregation that are pushed to the data source via - * {@link #pushAggregation(Aggregation aggregation)}. - */ - Aggregation pushedAggregation(); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index dde5dc2be0556..f2aa5365cc7ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.execution.datasources.PartialAggregatePushDown import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes} @@ -37,7 +38,8 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil + SchemaPruning :: PartialAggregatePushDown :: V2ScanRelationPushDown :: V2Writes :: + PruneFileSourcePartitions :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala new file mode 100644 index 0000000000000..60322278215e2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.connector.read.{Scan, V1Scan} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.Aggregation +import org.apache.spark.sql.types.StructType + +/** + * Push down partial Aggregate to datasource for better performance + */ +object PartialAggregatePushDown extends Rule[LogicalPlan] with AliasHelper { + import DataSourceV2Implicits._ + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Pattern matching for partial aggregate push down + // For parquet, footer only has statistics information for max/min/count. + // It doesn't handle max/min/count associated with filter or group by. + // ORC is similar. If JDBC partial aggregate push down is added later, + // these condition checks need to be changed. + case aggNode@Aggregate(groupingExpressions, resultExpressions, child) + if (groupingExpressions.isEmpty) => + child match { + case ScanOperation(project, filters, relation@DataSourceV2Relation(table, _, _, _, _)) + if (filters.isEmpty) && table.isInstanceOf[ParquetTable] => + val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + var aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => + replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] + } + } + aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) + .asInstanceOf[Seq[AggregateExpression]] + + val scan = scanBuilder.build() + val translatedAggregates = aggregates.map(DataSourceStrategy + .translateAggregate(_, PushableColumn(false))) + if (translatedAggregates.exists(_.isEmpty)) { + aggNode // return original plan node + } else { + val aggregation = Aggregation(translatedAggregates.flatten, Seq.empty) + scan.pushAggregation(aggregation) + // use the aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // Use min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + val output = aggregates.map { + case agg: AggregateExpression => + AttributeReference(toPrettySQL(agg), agg.dataType)() + } + + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. 
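          // As a concrete sketch (the literal values are illustrative): for
          //   SELECT min(c1), max(c1) FROM t          -- c1 is INT, as in the example above
          // the aggregation pushed to the scan and the replacement output built here are roughly
          //   Aggregation(Seq(Min("c1", IntegerType), Max("c1", IntegerType)), Seq.empty)
          //   Seq(AttributeReference("min(c1)", IntegerType)(),
          //       AttributeReference("max(c1)", IntegerType)())
          // using Min/Max from org.apache.spark.sql.sources and toPrettySQL for the attribute names.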
+ + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + val wrappedScan = scan match { + case v1: V1Scan => + V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter], aggregation) + case _ => scan + } + + val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // The original logical plan is + // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9] parquet ... + // + // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // we have the following + // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] + // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { + case _: aggregate.Max => aggregate.Max(output(i - 1)) + case _: aggregate.Min => aggregate.Min(output(i - 1)) + case _: aggregate.Sum => aggregate.Sum(output(i - 1)) + case _: aggregate.Average => aggregate.Average(output(i - 1)) + case _: aggregate.Count => aggregate.Sum(output(i - 1)) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction, filter = None) + } + } + + case _ => aggNode // return original plan node + } + } +} + +// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by +// the physical v1 scan node. 
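// The pushedAggregates field below carries the translated aggregation alongside the translated
// filters, presumably so the physical v1 scan node can expose the pushed-down aggregation the
// same way it exposes pushed-down filters.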
+case class V1ScanWrapper( + v1Scan: V1Scan, + translatedFilters: Seq[sources.Filter], + handledFilters: Seq[sources.Filter], + pushedAggregates: sources.Aggregation) extends Scan { + override def readSchema(): StructType = v1Scan.readSchema() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 0745deb82c998..1f57f17911457 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.datasources.PushableColumn import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.Aggregation import org.apache.spark.sql.types.StructType object PushDownUtils extends PredicateHelper { @@ -73,31 +70,6 @@ object PushDownUtils extends PredicateHelper { } } - /** - * Pushes down aggregates to the data source reader - * - * @return pushed aggregation. - */ - def pushAggregates( - scanBuilder: ScanBuilder, - aggregates: Seq[AggregateExpression], - groupBy: Seq[Expression]): Aggregation = { - - scanBuilder match { - case r: SupportsPushDownAggregates => - val translatedAggregates = aggregates.map(DataSourceStrategy - .translateAggregate(_, PushableColumn(false))) - - if (translatedAggregates.exists(_.isEmpty)) { - Aggregation.empty - } else { - r.pushAggregation(Aggregation(translatedAggregates.flatten, Seq.empty)) - r.pushedAggregation - } - case _ => Aggregation.empty - } - } - /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 64951e43052fe..673a1dd462337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.read.{Scan, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources @@ -33,101 +31,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - // Pattern matching for partial aggregate push down - // For parquet, footer only has statistics information for max/min/count. - // It doesn't handle max/min/count associated with filter or group by. - // ORC is similar. If JDBC partial aggregate push down is added later, - // these condition checks need to be changed. - case aggNode@Aggregate(groupingExpressions, resultExpressions, child) - if (groupingExpressions.isEmpty) => - child match { - case ScanOperation(project, filters, relation: DataSourceV2Relation) - if (filters.isEmpty) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - var aggregates = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => - replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] - } - } - aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) - .asInstanceOf[Seq[AggregateExpression]] - - val aggregation = PushDownUtils - .pushAggregates(scanBuilder, aggregates, groupingExpressions) - if (aggregation.aggregateExpressions.isEmpty) { - aggNode // return original plan node - } else { - // use the aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // Use min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - val output = aggregates.map { - case agg: AggregateExpression => - AttributeReference(toPrettySQL(agg), agg.dataType)() - } - - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. Since PushDownUtils.pruneColumns is not called, - // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is - // not used anyways. The schema for aggregate columns will be built in Scan. 
- val scan = scanBuilder.build() - - logInfo( - s""" - |Pushing operators to ${relation.name} - |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - val wrappedScan = scan match { - case v1: V1Scan => - V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter], aggregation) - case _ => scan - } - - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) - - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // The original logical plan is - // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9] parquet ... - // - // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] - // we have the following - // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] - // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - var i = 0 - plan.transformExpressions { - case agg: AggregateExpression => - i += 1 - val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { - case _: aggregate.Max => aggregate.Max(output(i - 1)) - case _: aggregate.Min => aggregate.Min(output(i - 1)) - case _: aggregate.Sum => aggregate.Sum(output(i - 1)) - case _: aggregate.Average => aggregate.Average(output(i - 1)) - case _: aggregate.Count => aggregate.Sum(output(i - 1)) - case _ => agg.aggregateFunction - } - agg.copy(aggregateFunction = aggFunction, filter = None) - } - } - - case _ => aggNode // return original plan node - } - case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index d634f85564f1e..d55e26e9e56a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{Aggregation, Filter} +import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -42,11 +42,11 @@ case class ParquetScan( readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], - pushedAggregations: Aggregation = Aggregation.empty, options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true + private var pushedAggregations = Aggregation.empty override 
def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json @@ -120,4 +120,25 @@ case class ParquetScan( a.aggregateExpressions.sortBy(_.hashCode()) .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) } + + override def pushAggregation(aggregation: Aggregation): Unit = { + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByColumns.nonEmpty) { + Aggregation.empty + return + } + + aggregation.aggregateExpressions.foreach { agg => + if (!agg.isInstanceOf[Max] && !agg.isInstanceOf[Min] && !agg.isInstanceOf[Count]) { + Aggregation.empty + return + } else if (agg.isInstanceOf[Count] && agg.asInstanceOf[Count].isDistinct) { + // parquet's statistics doesn't have distinct count info + Aggregation.empty + return + } + } + pushedAggregations = aggregation + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index d7fa474e1133c..44053830defe5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -34,8 +34,7 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters - with SupportsPushDownAggregates { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -71,32 +70,8 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. 
override def pushedFilters(): Array[Filter] = pushedParquetFilters - private var pushedAggregations = Aggregation.empty - - override def pushAggregation(aggregation: Aggregation): Unit = { - if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByColumns.nonEmpty) { - Aggregation.empty - return - } - - aggregation.aggregateExpressions.foreach { agg => - if (!agg.isInstanceOf[Max] && !agg.isInstanceOf[Min] && !agg.isInstanceOf[Count]) { - Aggregation.empty - return - } else if (agg.isInstanceOf[Count] && agg.asInstanceOf[Count].isDistinct) { - // parquet's statistics doesn't have distinct count info - Aggregation.empty - return - } - } - this.pushedAggregations = aggregation - } - - override def pushedAggregation(): Aggregation = pushedAggregations - override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, pushedAggregations, options) + readPartitionSchema(), pushedParquetFilters, options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index fcc95de2d213d..4e7fe8455ff93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.json.JsonScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.datasources.v2.text.TextScan -import org.apache.spark.sql.sources.{Aggregation, Filter} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -354,8 +354,7 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, Aggregation.empty, - o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => From 743bc8aad9055941b4b6d02639f429659a60522f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 27 Apr 2021 23:20:57 -0700 Subject: [PATCH 14/30] fix tests failure --- .../java/org/apache/spark/sql/connector/read/Scan.java | 8 ++++++++ .../execution/datasources/PartialAggregatePushDown.scala | 6 +++--- .../execution/datasources/v2/parquet/ParquetScan.scala | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index dd09c37cada90..f594983e947c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -123,4 +123,12 @@ default void pushAggregation(Aggregation aggregation) { throw new UnsupportedOperationException(description() + ": Push down Aggregation is not supported"); } + + /* + * Returns the aggregation that is pushed to the Scan + */ + default Aggregation pushedAggregation() { + throw new UnsupportedOperationException(description() + + ": pushedAggregation is not supported"); + } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala index 60322278215e2..e92632b99ea58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala @@ -60,11 +60,11 @@ object PartialAggregatePushDown extends Rule[LogicalPlan] with AliasHelper { val scan = scanBuilder.build() val translatedAggregates = aggregates.map(DataSourceStrategy .translateAggregate(_, PushableColumn(false))) - if (translatedAggregates.exists(_.isEmpty)) { + val aggregation = Aggregation(translatedAggregates.flatten, Seq.empty) + scan.pushAggregation(aggregation) + if (scan.pushedAggregation().aggregateExpressions.isEmpty) { aggNode // return original plan node } else { - val aggregation = Aggregation(translatedAggregates.flatten, Seq.empty) - scan.pushAggregation(aggregation) // use the aggregate columns as the output columns // e.g. TABLE t (c1 INT, c2 INT, c3 INT) // SELECT min(c1), max(c1) FROM t; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index d55e26e9e56a7..38259adb07162 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -141,4 +141,5 @@ case class ParquetScan( pushedAggregations = aggregation } + override def pushedAggregation(): Aggregation = pushedAggregations } From 346485e2e220f5d301ee36bdb44415f09c055521 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 15 May 2021 13:39:28 -0700 Subject: [PATCH 15/30] add interface SupportsPushDownAggregates --- .../apache/spark/sql/connector/read/Scan.java | 18 --- .../spark/sql/connector/read/ScanBuilder.java | 6 + .../read/SupportsPushDownAggregates.java | 60 +++++++ .../apache/spark/sql/sources/aggregates.scala | 10 +- .../spark/sql/execution/SparkOptimizer.scala | 3 +- .../datasources/DataSourceStrategy.scala | 8 +- .../PartialAggregatePushDown.scala | 148 ------------------ .../datasources/parquet/ParquetUtils.scala | 6 +- .../datasources/v2/PushDownUtils.scala | 37 ++++- .../v2/V2ScanRelationPushDown.scala | 115 +++++++++++++- .../ParquetPartitionReaderFactory.scala | 39 +---- .../datasources/v2/parquet/ParquetScan.scala | 28 +--- .../v2/parquet/ParquetScanBuilder.scala | 56 ++++++- .../org/apache/spark/sql/FileScanSuite.scala | 5 +- .../parquet/ParquetQuerySuite.scala | 90 +++++------ 15 files changed, 336 insertions(+), 293 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index f594983e947c1..71e36bb2dd915 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -113,22 +113,4 @@ default CustomMetric[] supportedCustomMetrics() { CustomMetric[] NO_METRICS = {}; return NO_METRICS; } - - 
/** - * Pushes down Aggregation to scan. - * The Aggregation can be pushed down only if all the Aggregate Functions can - * be pushed down. - */ - default void pushAggregation(Aggregation aggregation) { - throw new UnsupportedOperationException(description() + - ": Push down Aggregation is not supported"); - } - - /* - * Returns the aggregation that is pushed to the Scan - */ - default Aggregation pushedAggregation() { - throw new UnsupportedOperationException(description() + - ": pushedAggregation is not supported"); - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index cb3eea7680058..c7655d77e6f4d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -28,5 +28,11 @@ */ @Evolving public interface ScanBuilder { + enum orders { FILTER, AGGREGATE, COLUMNS }; + + // Orders of operators push down. Spark will push down filters first, then aggregates, and finally + // column pruning (if applicable). + static orders[] PUSH_DOWN_ORDERS = {orders.FILTER, orders.AGGREGATE, orders.COLUMNS}; + Scan build(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java new file mode 100644 index 0000000000000..7304f6797792f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.Aggregation; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link ScanBuilder}. Data source can implement this interface to + * push down aggregates to the data source. + * + * @since 3.2.0 + */ +@Evolving +public interface SupportsPushDownAggregates extends ScanBuilder { + + /** + * Pushes down Aggregation to datasource. + * The Aggregation can be pushed down only if all the Aggregate Functions can + * be pushed down. + */ + void pushAggregation(Aggregation aggregation); + + /** + * Returns the aggregation that are pushed to the data source via + * {@link #pushAggregation(Aggregation aggregation)}. 
+ */ + Aggregation pushedAggregation(); + + /** + * Returns the schema of the pushed down aggregates + */ + StructType getPushDownAggSchema(); + + /** + * Indicate if the data source only supports global aggregated push down + */ + boolean supportsGlobalAggregatePushDownOnly(); + + /** + * Indicate if the data source supports push down aggregates along with filters + */ + boolean supportsPushDownAggregateWithFilter(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala index 2b98454046a47..646082a2d44f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.types.DataType // Aggregate Functions in SQL statement. -// e.g. SELECT COUNT(EmployeeID), AVG(salary), deptID FROM dept GROUP BY deptID -// aggregateExpressions are (COUNT(EmployeeID), AVG(salary)), groupByColumns are (deptID) -case class Aggregation(aggregateExpressions: Seq[AggregateFunc], +// e.g. SELECT COUNT(EmployeeID), Max(salary), deptID FROM dept GROUP BY deptID +// aggregateExpressions are (COUNT(EmployeeID), Max(salary)), groupByColumns are (deptID) +case class Aggregation(aggregateExpressions: Seq[Seq[AggregateFunc]], groupByColumns: Seq[String]) abstract class AggregateFunc -// Avg and Sum are only supported by JDBC agg pushdown, not supported by parquet agg pushdown yet -case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc case class Min(column: String, dataType: DataType) extends AggregateFunc case class Max(column: String, dataType: DataType) extends AggregateFunc case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc @@ -36,5 +34,5 @@ case class Count(column: String, dataType: DataType, isDistinct: Boolean) extend object Aggregation { // Returns an empty Aggregate - def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]) + def empty: Aggregation = Aggregation(Seq.empty[Seq[AggregateFunc]], Seq.empty[String]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index f2aa5365cc7ce..9c7a05c36a8f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.execution.datasources.PartialAggregatePushDown import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes} @@ -38,7 +37,7 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: PartialAggregatePushDown :: V2ScanRelationPushDown :: V2Writes :: + SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches 
:+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c72f64e0b3d35..954ef7b1fc704 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -679,19 +679,19 @@ object DataSourceStrategy protected[sql] def translateAggregate( aggregates: AggregateExpression, - pushableColumn: PushableColumnBase): Option[AggregateFunc] = { + pushableColumn: PushableColumnBase): Option[Seq[AggregateFunc]] = { aggregates.aggregateFunction match { case min@aggregate.Min(pushableColumn(name)) => - Some(Min(name, min.dataType)) + Some(Seq(Min(name, min.dataType))) case max@aggregate.Max(pushableColumn(name)) => - Some(Max(name, max.dataType)) + Some(Seq(Max(name, max.dataType))) case count: aggregate.Count => val columnName = count.children.head match { // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table case Literal(_, _) => "1" case pushableColumn(name) => name } - Some(Count(columnName, count.dataType, aggregates.isDistinct)) + Some(Seq(Count(columnName, count.dataType, aggregates.isDistinct))) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala deleted file mode 100644 index e92632b99ea58..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartialAggregatePushDown.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.toPrettySQL -import org.apache.spark.sql.connector.read.{Scan, V1Scan} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation} -import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.Aggregation -import org.apache.spark.sql.types.StructType - -/** - * Push down partial Aggregate to datasource for better performance - */ -object PartialAggregatePushDown extends Rule[LogicalPlan] with AliasHelper { - import DataSourceV2Implicits._ - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Pattern matching for partial aggregate push down - // For parquet, footer only has statistics information for max/min/count. - // It doesn't handle max/min/count associated with filter or group by. - // ORC is similar. If JDBC partial aggregate push down is added later, - // these condition checks need to be changed. - case aggNode@Aggregate(groupingExpressions, resultExpressions, child) - if (groupingExpressions.isEmpty) => - child match { - case ScanOperation(project, filters, relation@DataSourceV2Relation(table, _, _, _, _)) - if (filters.isEmpty) && table.isInstanceOf[ParquetTable] => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - var aggregates = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => - replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] - } - } - aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) - .asInstanceOf[Seq[AggregateExpression]] - - val scan = scanBuilder.build() - val translatedAggregates = aggregates.map(DataSourceStrategy - .translateAggregate(_, PushableColumn(false))) - val aggregation = Aggregation(translatedAggregates.flatten, Seq.empty) - scan.pushAggregation(aggregation) - if (scan.pushedAggregation().aggregateExpressions.isEmpty) { - aggNode // return original plan node - } else { - // use the aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // Use min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - val output = aggregates.map { - case agg: AggregateExpression => - AttributeReference(toPrettySQL(agg), agg.dataType)() - } - - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. Since PushDownUtils.pruneColumns is not called, - // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is - // not used anyways. The schema for aggregate columns will be built in Scan. 
- - - logInfo( - s""" - |Pushing operators to ${relation.name} - |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - val wrappedScan = scan match { - case v1: V1Scan => - V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter], aggregation) - case _ => scan - } - - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) - - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // The original logical plan is - // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9] parquet ... - // - // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] - // we have the following - // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] - // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - var i = 0 - plan.transformExpressions { - case agg: AggregateExpression => - i += 1 - val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { - case _: aggregate.Max => aggregate.Max(output(i - 1)) - case _: aggregate.Min => aggregate.Min(output(i - 1)) - case _: aggregate.Sum => aggregate.Sum(output(i - 1)) - case _: aggregate.Average => aggregate.Average(output(i - 1)) - case _: aggregate.Count => aggregate.Sum(output(i - 1)) - case _ => agg.aggregateFunction - } - agg.copy(aggregateFunction = aggFunction, filter = None) - } - } - - case _ => aggNode // return original plan node - } - } -} - -// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by -// the physical v1 scan node. 
-case class V1ScanWrapper( - v1Scan: V1Scan, - translatedFilters: Seq[sources.Filter], - handledFilters: Seq[sources.Filter], - pushedAggregates: sources.Aggregation) extends Scan { - override def readSchema(): StructType = v1Scan.readSchema() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 73a514432f06a..c2951f9af0e6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -357,7 +357,7 @@ object ParquetUtils { blocks.forEach { block => val blockMetaData = block.getColumns() aggregation.aggregateExpressions(i) match { - case Max(col, _) => + case Seq(Max(col, _)) => index = dataSchema.fieldNames.toList.indexOf(col) val currentMax = getCurrentBlockMaxOrMin(footer, blockMetaData, index, true) if (currentMax != None && @@ -365,7 +365,7 @@ object ParquetUtils { value = currentMax } - case Min(col, _) => + case Seq(Min(col, _)) => index = dataSchema.fieldNames.toList.indexOf(col) val currentMin = getCurrentBlockMaxOrMin(footer, blockMetaData, index, false) if (currentMin != None && @@ -373,7 +373,7 @@ object ParquetUtils { value = currentMin } - case Count(col, _, _) => + case Seq(Count(col, _, _)) => index = dataSchema.fieldNames.toList.indexOf(col) rowCount += block.getRowCount if (!col.equals("1")) { // "1" is for count(*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 1f57f17911457..f778f8e6e2295 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.Aggregation import org.apache.spark.sql.types.StructType object PushDownUtils extends PredicateHelper { @@ -70,6 +72,37 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down aggregates to the data source reader + * + * @return pushed aggregation. 
+ */ + def pushAggregates( + scanBuilder: ScanBuilder, + aggregates: Seq[AggregateExpression], + groupBy: Seq[Expression]): Aggregation = { + + def columnAsString(e: Expression): String = e match { + case AttributeReference(name, _, _, _) => name + case _ => "" + } + + scanBuilder match { + case r: SupportsPushDownAggregates => + val translatedAggregates = aggregates.map(DataSourceStrategy + .translateAggregate(_, PushableColumn(false))) + val translatedGroupBys = groupBy.map(columnAsString) + + if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) { + Aggregation.empty + } else { + r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys)) + r.pushedAggregation + } + case _ => Aggregation.empty + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 673a1dd462337..14bfd321c923b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.{Scan, V1Scan} +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.sources.{AggregateFunc, Aggregation} @@ -31,6 +33,113 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case aggNode@Aggregate(groupingExpressions, resultExpressions, child) => + child match { + case ScanOperation(project, filters, relation: DataSourceV2Relation) => + val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + scanBuilder match { + case r: SupportsPushDownAggregates if r.supportsPushDownAggregateWithFilter => + // todo: need to change this for JDBC aggregate push down + // for (pushDownOrder <- ScanBuilder.PUSH_DOWN_ORDERS) { + // if (pushDownOrder == orders.FILTER) { + // pushdown filter + // } else if(pushDownOrder == orders.AGGREGATE) { + // pushdown aggregate + // } + aggNode + case r: SupportsPushDownAggregates => + if (filters.isEmpty) { + if (r.supportsGlobalAggregatePushDownOnly() && groupingExpressions.nonEmpty) { + aggNode // return original plan node + } else { + var aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => + replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] + } + } + aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) + .asInstanceOf[Seq[AggregateExpression]] + val aggregation = 
PushDownUtils + .pushAggregates(scanBuilder, aggregates, groupingExpressions) + if (aggregation.aggregateExpressions.isEmpty) { + aggNode // return original plan node + } else { + // use the aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // Use min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + val output = aggregates.map { + case agg: AggregateExpression => + AttributeReference(toPrettySQL(agg), agg.dataType)() + } + + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. + val scan = scanBuilder.build() + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed Aggregates: ${aggregation.aggregateExpressions.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + val wrappedScan = scan match { + case v1: V1Scan => + V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter], + aggregation) + case _ => scan + } + + val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // The original logical plan is + // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9] parquet ... + // + // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // we have the following + // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] + // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... 
+ var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case _: aggregate.Max => aggregate.Max(output(i - 1)) + case _: aggregate.Min => aggregate.Min(output(i - 1)) + case _: aggregate.Sum => aggregate.Sum(output(i - 1)) + case _: aggregate.Count => aggregate.Sum(output(i - 1)) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction, filter = None) + } + } + } + } else { + aggNode + } + } + + case _ => aggNode // return original plan node + } case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) @@ -62,7 +171,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { case v1: V1Scan => val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) V1ScanWrapper(v1, translated, pushedFilters, - Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])) + Aggregation(Seq.empty[Seq[AggregateFunc]], Seq.empty[String])) case _ => scan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index bc1ac8e985835..caf259b69b377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -40,8 +40,8 @@ import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} -import org.apache.spark.sql.types.{AtomicType, LongType, StructField, StructType} +import org.apache.spark.sql.sources.{Aggregation, Filter} +import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -53,6 +53,7 @@ import org.apache.spark.util.SerializableConfiguration * @param dataSchema Schema of Parquet files. * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. + * @aggSchema Schema of the pushed down aggregation. * @param filters Filters to be pushed down in the batch scan. * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read. 
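Note that for pushed-down aggregates the reader factory above never reads row data; the values come entirely from the footer via ParquetUtils.getPushedDownAggResult. As a rough standalone illustration of that idea (not part of this patch; the file path and column index are placeholders, and the entry point through HadoopInputFile is just one convenient way to reach the footer), MIN, MAX and COUNT for a single column can be derived from per-row-group statistics with plain parquet-mr calls:

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.hadoop.util.HadoopInputFile

object FooterStatsSketch {
  // Returns (min, max, count(col)) for the column at `columnIndex`, computed purely from
  // row-group metadata; no data pages are read. Throws if a chunk carries no statistics,
  // mirroring the UnsupportedOperationException paths in ParquetUtils.
  def minMaxCount(path: String, columnIndex: Int): (Any, Any, Long) = {
    val reader = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(path), new Configuration()))
    try {
      var min: Any = null
      var max: Any = null
      var count = 0L
      reader.getFooter.getBlocks.asScala.foreach { block =>
        val stats = block.getColumns.get(columnIndex).getStatistics
        if (!stats.hasNonNullValue || !stats.isNumNullsSet) {
          throw new UnsupportedOperationException("no statistics for column chunk")
        }
        // count(col) = rows in the row group minus nulls; count(*) would skip the subtraction
        count += block.getRowCount - stats.getNumNulls
        val blockMin = stats.genericGetMin.asInstanceOf[Comparable[Any]]
        val blockMax = stats.genericGetMax.asInstanceOf[Comparable[Any]]
        if (min == null || blockMin.compareTo(min) < 0) min = blockMin
        if (max == null || blockMax.compareTo(max) > 0) max = blockMax
      }
      (min, max, count)
    } finally {
      reader.close()
    }
  }
}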
@@ -63,6 +64,7 @@ case class ParquetPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + aggSchema: StructType, filters: Array[Filter], aggregation: Aggregation, parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { @@ -87,25 +89,6 @@ case class ParquetPartitionReaderFactory( private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead - private lazy val aggSchema = { - var schema = new StructType() - aggregation.aggregateExpressions.map { - case Max(col, _) => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) - schema = schema.add(field.copy("max(" + field.name + ")")) - case Min(col, _) => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) - schema = schema.add(field.copy("min(" + field.name + ")")) - case Count(col, _, _) => - if (col.equals("1")) { - schema = schema.add(new StructField("count(*)", LongType)) - } else { - schema = schema.add(new StructField("count(" + col + ")", LongType)) - } - case _ => - } - schema - } private def getFooter(file: PartitionedFile): ParquetMetadata = { val conf = broadcastedConf.value.value @@ -157,13 +140,10 @@ case class ParquetPartitionReaderFactory( new PartitionReader[InternalRow] { var count = 0 - override def next(): Boolean = { - val hasNext = if (count == 0) true else false - count += 1 - hasNext - } + override def next(): Boolean = if (count == 0) true else false override def get(): InternalRow = { + count += 1 val footer = getFooter(file) val (parquetTypes, values) = ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) @@ -196,13 +176,10 @@ case class ParquetPartitionReaderFactory( new PartitionReader[ColumnarBatch] { var count = 0 - override def next(): Boolean = { - val hasNext = if (count == 0) true else false - count += 1 - hasNext - } + override def next(): Boolean = if (count == 0) true else false override def get(): ColumnarBatch = { + count += 1 val footer = getFooter(file) val (parquetTypes, values) = ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 38259adb07162..2b0bfb831f349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} +import org.apache.spark.sql.sources.{Aggregation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -42,11 +42,12 @@ case class ParquetScan( readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], + pushedAggregations: Aggregation = Aggregation.empty, + 
pushedDownAggSchema: StructType, options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true - private var pushedAggregations = Aggregation.empty override def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json @@ -86,6 +87,7 @@ case class ParquetScan( dataSchema, readDataSchema, readPartitionSchema, + pushedDownAggSchema, pushedFilters, pushedAggregations, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) @@ -120,26 +122,4 @@ case class ParquetScan( a.aggregateExpressions.sortBy(_.hashCode()) .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) } - - override def pushAggregation(aggregation: Aggregation): Unit = { - if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByColumns.nonEmpty) { - Aggregation.empty - return - } - - aggregation.aggregateExpressions.foreach { agg => - if (!agg.isInstanceOf[Max] && !agg.isInstanceOf[Min] && !agg.isInstanceOf[Count]) { - Aggregation.empty - return - } else if (agg.isInstanceOf[Count] && agg.asInstanceOf[Count].isDistinct) { - // parquet's statistics doesn't have distinct count info - Aggregation.empty - return - } - } - pushedAggregations = aggregation - } - - override def pushedAggregation(): Aggregation = pushedAggregations } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 44053830defe5..ea598f5e942d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} +import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -34,7 +34,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters + with SupportsPushDownAggregates{ lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -70,8 +71,53 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. 
override def pushedFilters(): Array[Filter] = pushedParquetFilters + private var pushedAggregations = Aggregation.empty + + override def pushAggregation(aggregation: Aggregation): Unit = { + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByColumns.nonEmpty) { + Aggregation.empty + return + } + + aggregation.aggregateExpressions.foreach { _ match { + // parquet's statistics doesn't have distinct count info + case Seq(Max(_, _)) | Seq(Min(_, _)) | Seq(Count(_, _, false)) => + case _ => Aggregation.empty + } + } + this.pushedAggregations = aggregation + } + + override def pushedAggregation(): Aggregation = pushedAggregations + + override def supportsGlobalAggregatePushDownOnly(): Boolean = true + + override def supportsPushDownAggregateWithFilter(): Boolean = false + + override def getPushDownAggSchema: StructType = { + var schema = new StructType() + pushedAggregations.aggregateExpressions.map { + case Seq(Max(col, _)) => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) + schema = schema.add(field.copy("max(" + field.name + ")")) + case Seq(Min(col, _)) => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) + schema = schema.add(field.copy("min(" + field.name + ")")) + case Seq(Count(col, _, _)) => + if (col.equals("1")) { + schema = schema.add(new StructField("count(*)", LongType)) + } else { + schema = schema.add(new StructField("count(" + col + ")", LongType)) + } + case _ => + } + schema + } + override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + readPartitionSchema(), pushedParquetFilters, pushedAggregations, getPushDownAggSchema, + options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 4e7fe8455ff93..19f47f18d4ff5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.json.JsonScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.datasources.v2.text.TextScan -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{Aggregation, Filter} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -354,7 +354,8 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, Aggregation.empty, + null, o, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2a67162c86ca4..277eeda5f0ce0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala 
@@ -925,16 +925,16 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS selectAgg3.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [Min(_3,IntegerType), " + - "Min(_3,IntegerType), " + - "Max(_3,IntegerType), " + - "Min(_1,IntegerType), " + - "Max(_1,IntegerType), " + - "Max(_1,IntegerType), " + - "Count(1,LongType,false), " + - "Count(_1,LongType,false), " + - "Count(_2,LongType,false), " + - "Count(_3,LongType,false)]" + "PushedAggregation: [List(Min(_3,IntegerType)), " + + "List(Min(_3,IntegerType)), " + + "List(Max(_3,IntegerType)), " + + "List(Min(_1,IntegerType)), " + + "List(Max(_1,IntegerType)), " + + "List(Max(_1,IntegerType)), " + + "List(Count(1,LongType,false)), " + + "List(Count(_1,LongType,false)), " + + "List(Count(_2,LongType,false)), " + + "List(Count(_3,LongType,false))]" checkKeywordsExistsInExplain(selectAgg3, expected_plan_fragment) } @@ -1030,17 +1030,17 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS testMin.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [Min(StringCol,StringType), " + - "Min(BooleanCol,BooleanType), " + - "Min(ByteCol,ByteType), " + - "Min(BinaryCol,BinaryType), " + - "Min(ShortCol,ShortType), " + - "Min(IntegerCol,IntegerType), " + - "Min(LongCol,LongType), " + - "Min(FloatCol,FloatType), " + - "Min(DoubleCol,DoubleType), " + - "Min(DecimalCol,DecimalType(25,5)), " + - "Min(DateCol,DateType)]" + "PushedAggregation: [List(Min(StringCol,StringType)), " + + "List(Min(BooleanCol,BooleanType)), " + + "List(Min(ByteCol,ByteType)), " + + "List(Min(BinaryCol,BinaryType)), " + + "List(Min(ShortCol,ShortType)), " + + "List(Min(IntegerCol,IntegerType)), " + + "List(Min(LongCol,LongType)), " + + "List(Min(FloatCol,FloatType)), " + + "List(Min(DoubleCol,DoubleType)), " + + "List(Min(DecimalCol,DecimalType(25,5))), " + + "List(Min(DateCol,DateType))]" checkKeywordsExistsInExplain(testMin, expected_plan_fragment) } @@ -1055,17 +1055,17 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS testMax.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [Max(StringCol,StringType), " + - "Max(BooleanCol,BooleanType)," + - " Max(ByteCol,ByteType), " + - "Max(BinaryCol,BinaryType)," + - " Max(ShortCol,ShortType), " + - "Max(IntegerCol,IntegerType)," + - " Max(LongCol,LongType), " + - "Max(FloatCol,FloatType)," + - " Max(DoubleCol,DoubleType), " + - "Max(DecimalCol,DecimalType(25,5)), " + - "Max(DateCol,DateType)]" + "PushedAggregation: [List(Max(StringCol,StringType)), " + + "List(Max(BooleanCol,BooleanType)), " + + "List(Max(ByteCol,ByteType)), " + + "List(Max(BinaryCol,BinaryType)), " + + "List(Max(ShortCol,ShortType)), " + + "List(Max(IntegerCol,IntegerType)), " + + "List(Max(LongCol,LongType)), " + + "List(Max(FloatCol,FloatType)), " + + "List(Max(DoubleCol,DoubleType)), " + + "List(Max(DecimalCol,DecimalType(25,5))), " + + "List(Max(DateCol,DateType))]" checkKeywordsExistsInExplain(testMax, expected_plan_fragment) } @@ -1081,19 +1081,19 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS testCount.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [Count(1,LongType,false), " + - "Count(StringCol,LongType,false), " + - 
"Count(BooleanCol,LongType,false), " + - "Count(ByteCol,LongType,false), " + - "Count(BinaryCol,LongType,false), " + - "Count(ShortCol,LongType,false), " + - "Count(IntegerCol,LongType,false), " + - "Count(LongCol,LongType,false), " + - "Count(FloatCol,LongType,false), " + - "Count(DoubleCol,LongType,false), " + - "Count(DecimalCol,LongType,false), " + - "Count(DateCol,LongType,false), " + - "Count(TimestampCol,LongType,false)]" + "PushedAggregation: [List(Count(1,LongType,false)), " + + "List(Count(StringCol,LongType,false)), " + + "List(Count(BooleanCol,LongType,false)), " + + "List(Count(ByteCol,LongType,false)), " + + "List(Count(BinaryCol,LongType,false)), " + + "List(Count(ShortCol,LongType,false)), " + + "List(Count(IntegerCol,LongType,false)), " + + "List(Count(LongCol,LongType,false)), " + + "List(Count(FloatCol,LongType,false)), " + + "List(Count(DoubleCol,LongType,false)), " + + "List(Count(DecimalCol,LongType,false)), " + + "List(Count(DateCol,LongType,false)), " + + "List(Count(TimestampCol,LongType,false))]" checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } From ed7c19d38ec2d23b18334920bc227a11bb99ba66 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 15 May 2021 14:33:19 -0700 Subject: [PATCH 16/30] fix java lint error --- .../src/main/java/org/apache/spark/sql/connector/read/Scan.java | 1 - .../java/org/apache/spark/sql/connector/read/ScanBuilder.java | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index 71e36bb2dd915..b70a656c492a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -24,7 +24,6 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.sources.Aggregation; /** * A logical representation of a data source scan. This interface is used to provide logical diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index c7655d77e6f4d..40c4cbb90f2a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -32,7 +32,7 @@ enum orders { FILTER, AGGREGATE, COLUMNS }; // Orders of operators push down. Spark will push down filters first, then aggregates, and finally // column pruning (if applicable). 
- static orders[] PUSH_DOWN_ORDERS = {orders.FILTER, orders.AGGREGATE, orders.COLUMNS}; + orders[] PUSH_DOWN_ORDERS = {orders.FILTER, orders.AGGREGATE, orders.COLUMNS}; Scan build(); } From 52b0b9822636d14da46727ec323a0df4f364c233 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 15 May 2021 15:15:29 -0700 Subject: [PATCH 17/30] add default case --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 14bfd321c923b..ebfa6d080f6a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -136,6 +136,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { } else { aggNode } + case _ => aggNode } case _ => aggNode // return original plan node From 6df2ae1f4093e23c346270d0ae60205f6d46f1fa Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 27 May 2021 23:36:52 -0700 Subject: [PATCH 18/30] address comments --- .../spark/sql/connector/read/ScanBuilder.java | 10 +- .../read/SupportsPushDownAggregates.java | 2 +- .../sql/catalyst/planning/patterns.scala | 16 +- .../expressions}/aggregates.scala | 18 +- .../sql/execution/DataSourceScanExec.scala | 13 +- .../datasources/DataSourceStrategy.scala | 16 +- .../datasources/parquet/ParquetUtils.scala | 93 +++-- .../datasources/v2/DataSourceV2Strategy.scala | 1 - .../datasources/v2/PushDownUtils.scala | 10 +- .../v2/V2ScanRelationPushDown.scala | 322 +++++++++++------- .../ParquetPartitionReaderFactory.scala | 11 +- .../datasources/v2/parquet/ParquetScan.scala | 3 +- .../v2/parquet/ParquetScanBuilder.scala | 43 ++- .../org/apache/spark/sql/FileScanSuite.scala | 3 +- .../parquet/ParquetQuerySuite.scala | 107 +++--- 15 files changed, 377 insertions(+), 291 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{sources => connector/expressions}/aggregates.scala (64%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index 40c4cbb90f2a2..565af7cae3ccd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -24,15 +24,13 @@ * interfaces to do operator pushdown, and keep the operator pushdown result in the returned * {@link Scan}. * + * The operators in the Scan can be pushed down to the data source layer. + * If applicable (the operator is present and the source supports that operator), Spark pushes + * down filters to the source first, then push down aggregation and apply column pruning. + * * @since 3.0.0 */ @Evolving public interface ScanBuilder { - enum orders { FILTER, AGGREGATE, COLUMNS }; - - // Orders of operators push down. Spark will push down filters first, then aggregates, and finally - // column pruning (if applicable). 
- orders[] PUSH_DOWN_ORDERS = {orders.FILTER, orders.AGGREGATE, orders.COLUMNS}; - Scan build(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 7304f6797792f..1923826b9adaa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.Aggregation; +import org.apache.spark.sql.connector.expressions.Aggregation; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index c22a874779fca..a53c0121d73bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -48,6 +48,14 @@ trait OperationHelper { .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) } } + + protected def hasCommonNonDeterministic( + expr: Seq[Expression], + aliases: AttributeMap[Expression]): Boolean = { + expr.exists(_.collect { + case a: AttributeReference if aliases.contains(a) => aliases(a) + }.exists(!_.deterministic)) + } } /** @@ -116,14 +124,6 @@ object ScanOperation extends OperationHelper with PredicateHelper { } } - private def hasCommonNonDeterministic( - expr: Seq[Expression], - aliases: AttributeMap[Expression]): Boolean = { - expr.exists(_.collect { - case a: AttributeReference if aliases.contains(a) => aliases(a) - }.exists(!_.deterministic)) - } - private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = { plan match { case Project(fields, child) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala similarity index 64% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala index 646082a2d44f7..0e3729bf3648f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala @@ -15,24 +15,26 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.connector.expressions import org.apache.spark.sql.types.DataType // Aggregate Functions in SQL statement. // e.g. 
SELECT COUNT(EmployeeID), Max(salary), deptID FROM dept GROUP BY deptID // aggregateExpressions are (COUNT(EmployeeID), Max(salary)), groupByColumns are (deptID) -case class Aggregation(aggregateExpressions: Seq[Seq[AggregateFunc]], - groupByColumns: Seq[String]) +case class Aggregation(aggregateExpressions: Seq[AggregateFunc], + groupByColumns: Seq[FieldReference]) abstract class AggregateFunc -case class Min(column: String, dataType: DataType) extends AggregateFunc -case class Max(column: String, dataType: DataType) extends AggregateFunc -case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc -case class Count(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc +case class Min(column: FieldReference, dataType: DataType) extends AggregateFunc +case class Max(column: FieldReference, dataType: DataType) extends AggregateFunc +case class Sum(column: FieldReference, dataType: DataType, isDistinct: Boolean) + extends AggregateFunc +case class Count(column: FieldReference, dataType: DataType, isDistinct: Boolean) + extends AggregateFunc object Aggregation { // Returns an empty Aggregate - def empty: Aggregation = Aggregation(Seq.empty[Seq[AggregateFunc]], Seq.empty[String]) + def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[FieldReference]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index c50819893812c..6fa4167384925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter} +import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -102,7 +102,6 @@ case class RowDataSourceScanExec( requiredSchema: StructType, filters: Set[Filter], handledFilters: Set[Filter], - aggregation: Aggregation, rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -133,17 +132,9 @@ case class RowDataSourceScanExec( val markedFilters = for (filter <- filters) yield { if (handledFilters.contains(filter)) s"*$filter" else s"$filter" } - val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield { - s"*$aggregate" - } - val markedGroupby = for (groupby <- aggregation.groupByColumns) yield { - s"*$groupby" - } Map( "ReadSchema" -> requiredSchema.catalogString, - "PushedFilters" -> markedFilters.mkString("[", ", ", "]"), - "PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"), - "PushedGroupby" -> markedGroupby.mkString("[", ", ", "]")) + "PushedFilters" -> markedFilters.mkString("[", ", ", "]")) } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. 
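To make the connector-side representation introduced in aggregates.scala concrete, the example query quoted in its comment, SELECT COUNT(EmployeeID), MAX(salary) FROM dept GROUP BY deptID, would be described roughly as below. This is a hand-written sketch rather than output of the patch, and it assumes salary is an INT column; DataSourceStrategy.translateAggregate in the following file diff builds the equivalent values from Catalyst AggregateExpressions.

import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, Max}
import org.apache.spark.sql.types.{IntegerType, LongType}

object AggregationExample {
  // COUNT(EmployeeID) and MAX(salary) become AggregateFuncs; deptID becomes a group-by column.
  // COUNT always reports LongType, while MAX carries the column's own type
  // (IntegerType here by assumption).
  val pushed: Aggregation = Aggregation(
    aggregateExpressions = Seq(
      Count(FieldReference(Seq("EmployeeID")), LongType, isDistinct = false),
      Max(FieldReference(Seq("salary")), IntegerType)),
    groupByColumns = Seq(FieldReference(Seq("deptID"))))
}

A ScanBuilder that mixes in SupportsPushDownAggregates receives such a value through pushAggregation() and echoes back whatever it accepted via pushedAggregation(); an empty Aggregation means nothing was pushed down.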
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 954ef7b1fc704..8e347436fa977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, FieldReference, Max, Min} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -333,7 +334,6 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - Aggregation.empty, toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -407,7 +407,6 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - Aggregation.empty, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -430,7 +429,6 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - Aggregation.empty, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -679,19 +677,19 @@ object DataSourceStrategy protected[sql] def translateAggregate( aggregates: AggregateExpression, - pushableColumn: PushableColumnBase): Option[Seq[AggregateFunc]] = { + pushableColumn: PushableColumnBase): Option[AggregateFunc] = { aggregates.aggregateFunction match { - case min@aggregate.Min(pushableColumn(name)) => - Some(Seq(Min(name, min.dataType))) - case max@aggregate.Max(pushableColumn(name)) => - Some(Seq(Max(name, max.dataType))) + case min @ aggregate.Min(pushableColumn(name)) => + Some(Min(FieldReference(Seq(name)), min.dataType)) + case max @ aggregate.Max(pushableColumn(name)) => + Some(Max(FieldReference(Seq(name)), max.dataType)) case count: aggregate.Count => val columnName = count.children.head match { // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table case Literal(_, _) => "1" case pushableColumn(name) => name } - Some(Seq(Count(columnName, count.dataType, aggregates.isDistinct))) + Some(Count(FieldReference(Seq(columnName)), count.dataType, aggregates.isDistinct)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index c2951f9af0e6f..0edf38eb39675 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -29,13 +29,14 @@ import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.PrimitiveType +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, Max, Min} import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} -import org.apache.spark.sql.sources.{Aggregation, Count, Max, Min} import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.types.UTF8String @@ -154,15 +155,17 @@ object ParquetUtils { * * @return Aggregate results in the format of InternalRow */ - private[sql] def aggResultToSparkInternalRows( + private[sql] def createInternalRowFromAggResult( footer: ParquetMetadata, - parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], - values: Seq[Any], dataSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, datetimeRebaseModeInRead: String, int96RebaseModeInRead: String, convertTz: Option[ZoneId]): InternalRow = { - val mutableRow = new SpecificInternalRow(dataSchema.fields.map(x => x.dataType)) + val (parquetTypes, values) = + ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) + val mutableRow = new SpecificInternalRow(aggSchema.fields.map(x => x.dataType)) val footerFileMetaData = footer.getFileMetaData val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( footerFileMetaData.getKeyValueMetaData.get, @@ -170,9 +173,9 @@ object ParquetUtils { val int96RebaseMode = DataSourceUtils.int96RebaseMode( footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) - parquetTypes.zipWithIndex.map { + parquetTypes.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.INT32, i) => - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case ByteType => mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) case ShortType => @@ -186,19 +189,19 @@ object ParquetUtils { case d: DecimalType => val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for INT32") + case _ => throw new SparkException("Unexpected type for INT32") } case (PrimitiveType.PrimitiveTypeName.INT64, i) => - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case LongType => mutableRow.setLong(i, values(i).asInstanceOf[Long]) case d: DecimalType => val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for INT64") + case _ => throw new SparkException("Unexpected type for INT64") } case (PrimitiveType.PrimitiveTypeName.INT96, i) => - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case LongType => mutableRow.setLong(i, values(i).asInstanceOf[Long]) case TimestampType => @@ -211,7 +214,7 @@ object ParquetUtils { convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) mutableRow.setLong(i, adjTime) - case _ => throw new IllegalArgumentException("Unexpected type for INT96") + case _ => throw new SparkException("Unexpected type for INT96") } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => mutableRow.setFloat(i, values(i).asInstanceOf[Float]) 
@@ -221,7 +224,7 @@ object ParquetUtils { mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) case (PrimitiveType.PrimitiveTypeName.BINARY, i) => val bytes = values(i).asInstanceOf[Binary].getBytes - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case StringType => mutableRow.update(i, UTF8String.fromBytes(bytes)) case BinaryType => @@ -230,19 +233,19 @@ object ParquetUtils { val decimal = Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for Binary") + case _ => throw new SparkException("Unexpected type for Binary") } case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => val bytes = values(i).asInstanceOf[Binary].getBytes - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case d: DecimalType => val decimal = Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new IllegalArgumentException("Unexpected type for FIXED_LEN_BYTE_ARRAY") + case _ => throw new SparkException("Unexpected type for FIXED_LEN_BYTE_ARRAY") } case _ => - throw new IllegalArgumentException("Unexpected parquet type name") + throw new SparkException("Unexpected parquet type name") } mutableRow } @@ -256,15 +259,17 @@ object ParquetUtils { * * @return Aggregate results in the format of ColumnarBatch */ - private[sql] def aggResultToSparkColumnarBatch( + private[sql] def createColumnarBatchFromAggResult( footer: ParquetMetadata, - parquetTypes: Seq[PrimitiveType.PrimitiveTypeName], - values: Seq[Any], dataSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, offHeap: Boolean, datetimeRebaseModeInRead: String, int96RebaseModeInRead: String, convertTz: Option[ZoneId]): ColumnarBatch = { + val (parquetTypes, values) = + ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) val capacity = 4 * 1024 val footerFileMetaData = footer.getFileMetaData val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( @@ -274,14 +279,14 @@ object ParquetUtils { footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(capacity, dataSchema) + OffHeapColumnVector.allocateColumns(capacity, aggSchema) } else { - OnHeapColumnVector.allocateColumns(capacity, dataSchema) + OnHeapColumnVector.allocateColumns(capacity, aggSchema) } - parquetTypes.zipWithIndex.map { + parquetTypes.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.INT32, i) => - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case ByteType => columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) case ShortType => @@ -292,12 +297,12 @@ object ParquetUtils { val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( datetimeRebaseMode, "Parquet") columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) - case _ => throw new IllegalArgumentException("Unexpected type for INT32") + case _ => throw new SparkException("Unexpected type for INT32") } case (PrimitiveType.PrimitiveTypeName.INT64, i) => columnVectors(i).appendLong(values(i).asInstanceOf[Long]) case (PrimitiveType.PrimitiveTypeName.INT96, i) => - dataSchema.fields(i).dataType match { + aggSchema.fields(i).dataType match { case LongType => columnVectors(i).appendLong(values(i).asInstanceOf[Long]) case TimestampType => @@ -310,7 +315,7 @@ 
object ParquetUtils { convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) columnVectors(i).appendLong(adjTime) - case _ => throw new IllegalArgumentException("Unexpected type for INT96") + case _ => throw new SparkException("Unexpected type for INT96") } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) @@ -325,7 +330,7 @@ object ParquetUtils { case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) case _ => - throw new IllegalArgumentException("Unexpected parquet type name") + throw new SparkException("Unexpected parquet type name") } new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) } @@ -349,7 +354,7 @@ object ParquetUtils { val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] val valuesBuilder = ArrayBuilder.make[Any] - for (i <- 0 until aggregation.aggregateExpressions.size) { + aggregation.aggregateExpressions.indices.foreach { i => var value: Any = None var rowCount = 0L var isCount = false @@ -357,27 +362,27 @@ object ParquetUtils { blocks.forEach { block => val blockMetaData = block.getColumns() aggregation.aggregateExpressions(i) match { - case Seq(Max(col, _)) => - index = dataSchema.fieldNames.toList.indexOf(col) - val currentMax = getCurrentBlockMaxOrMin(footer, blockMetaData, index, true) + case Max(col, _) => + index = dataSchema.fieldNames.toList.indexOf(col.fieldNames.head) + val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) if (currentMax != None && (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { value = currentMax } - case Seq(Min(col, _)) => - index = dataSchema.fieldNames.toList.indexOf(col) - val currentMin = getCurrentBlockMaxOrMin(footer, blockMetaData, index, false) + case Min(col, _) => + index = dataSchema.fieldNames.toList.indexOf(col.fieldNames.head) + val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) if (currentMin != None && (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { value = currentMin } - case Seq(Count(col, _, _)) => - index = dataSchema.fieldNames.toList.indexOf(col) + case Count(col, _, _) => + index = dataSchema.fieldNames.toList.indexOf(col.fieldNames.head) rowCount += block.getRowCount - if (!col.equals("1")) { // "1" is for count(*) - rowCount -= getNumNulls(footer, blockMetaData, index) + if (!col.fieldNames.head.equals("1")) { // "1" is for count(*) + rowCount -= getNumNulls(blockMetaData, index) } isCount = true @@ -401,14 +406,9 @@ object ParquetUtils { * @return the Max or Min value */ private def getCurrentBlockMaxOrMin( - footer: ParquetMetadata, columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int, isMax: Boolean): Any = { - val parquetType = footer.getFileMetaData.getSchema.getType(i) - if (!parquetType.isPrimitive) { - throw new IllegalArgumentException("Unsupported type : " + parquetType.toString) - } val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.hasNonNullValue) { throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" + @@ -419,13 +419,8 @@ object ParquetUtils { } private def getNumNulls( - footer: ParquetMetadata, columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int): Long = { - val parquetType = footer.getFileMetaData.getSchema.getType(i) - if (!parquetType.isPrimitive) { - throw new IllegalArgumentException("Unsupported type: " + 
parquetType.toString) - } val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.isNumNullsSet()) { throw new UnsupportedOperationException("Number of nulls not set for parquet file." + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 6d69c8071988a..1585bc040fda1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -100,7 +100,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat output.toStructType, translated.toSet, pushed.toSet, - aggregation, unsafeRowRDD, v1Relation, tableIdentifier = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index f778f8e6e2295..4ef6a984c0d01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -22,11 +22,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.Aggregation import org.apache.spark.sql.types.StructType object PushDownUtils extends PredicateHelper { @@ -82,9 +82,9 @@ object PushDownUtils extends PredicateHelper { aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Aggregation = { - def columnAsString(e: Expression): String = e match { - case AttributeReference(name, _, _, _) => name - case _ => "" + def columnAsString(e: Expression): Option[FieldReference] = e match { + case AttributeReference(name, _, _, _) => Some(FieldReference(Seq(name))) + case _ => None } scanBuilder match { @@ -96,7 +96,7 @@ object PushDownUtils extends PredicateHelper { if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) { Aggregation.empty } else { - r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys)) + r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys.flatten)) r.pushedAggregation } case _ => Aggregation.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ebfa6d080f6a1..ba12c390f70a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,172 +19,200 @@ package 
org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.planning.{OperationHelper, ScanOperation} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.toPrettySQL -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, V1Scan} +import org.apache.spark.sql.connector.expressions.Aggregation +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.{AggregateFunc, Aggregation} import org.apache.spark.sql.types.StructType -object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper + with OperationHelper with PredicateHelper { import DataSourceV2Implicits._ - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case aggNode@Aggregate(groupingExpressions, resultExpressions, child) => + def apply(plan: LogicalPlan): LogicalPlan = { + applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan)))) + } + + private def createScanBuilder(plan: LogicalPlan) = plan.transform { + case r: DataSourceV2Relation => + ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options)) + } + + private def pushDownFilters(plan: LogicalPlan) = plan.transform { + // update the scan builder with filter push down and return a new plan with filter pushed + case filter @ Filter(_, sHolder: ScanBuilderHolder) => + val (filters, _, _) = collectFilters(filter).get + + val normalizedFilters = + DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output) + val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) + + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
+ val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( + sHolder.builder, normalizedFiltersWithoutSubquery) + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + val output = sHolder.output + + logInfo( + s""" + |Pushing operators to ${sHolder.relation.asInstanceOf[DataSourceV2Relation].name} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + """.stripMargin) + + val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionFunc = (expr: Expression) => expr transformDown { + case projectionOverSchema(newExpr) => newExpr + } + + val filterCondition = postScanFilters.reduceLeftOption(And) + val newFilterCondition = filterCondition.map(projectionFunc) + newFilterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) + } + + def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { + // update the scan builder with agg pushdown and return a new plan with agg pushed + case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - scanBuilder match { - case r: SupportsPushDownAggregates if r.supportsPushDownAggregateWithFilter => - // todo: need to change this for JDBC aggregate push down - // for (pushDownOrder <- ScanBuilder.PUSH_DOWN_ORDERS) { - // if (pushDownOrder == orders.FILTER) { - // pushdown filter - // } else if(pushDownOrder == orders.AGGREGATE) { - // pushdown aggregate - // } - aggNode + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + sHolder.builder match { case r: SupportsPushDownAggregates => - if (filters.isEmpty) { + if (filters.isEmpty || r.supportsPushDownAggregateWithFilter()) { if (r.supportsGlobalAggregatePushDownOnly() && groupingExpressions.nonEmpty) { aggNode // return original plan node } else { - var aggregates = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => - replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] - } - } - aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) - .asInstanceOf[Seq[AggregateExpression]] - val aggregation = PushDownUtils - .pushAggregates(scanBuilder, aggregates, groupingExpressions) - if (aggregation.aggregateExpressions.isEmpty) { - aggNode // return original plan node - } else { - // use the aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // Use min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - val output = aggregates.map { - case agg: AggregateExpression => - AttributeReference(toPrettySQL(agg), agg.dataType)() - } - - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. Since PushDownUtils.pruneColumns is not called, - // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is - // not used anyways. The schema for aggregate columns will be built in Scan. 
- val scan = scanBuilder.build() - - logInfo( - s""" - |Pushing operators to ${relation.name} - |Pushed Aggregates: ${aggregation.aggregateExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - val wrappedScan = scan match { - case v1: V1Scan => - V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter], - aggregation) - case _ => scan - } - - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) - - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // The original logical plan is - // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9] parquet ... - // - // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] - // we have the following - // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] - // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - var i = 0 - plan.transformExpressions { - case agg: AggregateExpression => - i += 1 - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case _: aggregate.Max => aggregate.Max(output(i - 1)) - case _: aggregate.Min => aggregate.Min(output(i - 1)) - case _: aggregate.Sum => aggregate.Sum(output(i - 1)) - case _: aggregate.Count => aggregate.Sum(output(i - 1)) - case _ => agg.aggregateFunction - } - agg.copy(aggregateFunction = aggFunction, filter = None) - } - } + val aggregates = getAggregateExpression(resultExpressions, project, sHolder) + val pushedAggregates = PushDownUtils + .pushAggregates(sHolder.builder, aggregates, groupingExpressions) + logInfo( + s""" + |Pushed Aggregates: ${pushedAggregates.aggregateExpressions.mkString(", ")} + """.stripMargin) + aggNode } } else { aggNode } case _ => aggNode } - - case _ => aggNode // return original plan node + case _ => aggNode } - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + } - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) - val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = - normalizedFilters.partition(SubqueryExpression.hasSubquery) + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { + case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => + child match { + case ScanOperation(project, _, sHolder: ScanBuilderHolder) => + sHolder.builder match { + case _: SupportsPushDownAggregates => + if (sHolder.builder.asInstanceOf[SupportsPushDownAggregates] + .pushedAggregation().aggregateExpressions.nonEmpty) { + val aggregates = getAggregateExpression(resultExpressions, project, sHolder) + // use the aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // Use min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... 
+ val output = aggregates.map { + case agg: AggregateExpression => + AttributeReference(toPrettySQL(agg), agg.dataType)() + } - // `pushedFilters` will be pushed down and evaluated in the underlying data sources. - // `postScanFilters` need to be evaluated after the scan. - // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) - val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. + val scan = sHolder.builder.build() + logInfo( + s""" + |Output: ${output.mkString(", ")} + """.stripMargin) + + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) + val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // The original logical plan is + // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9] parquet ... + // + // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // we have the following + // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] + // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... 
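The plan.transformExpressions rewrite just below swaps each partial aggregate for one that reads the scan's pre-aggregated output; in particular Count is replaced by Sum, because every scan task now returns a partial row count instead of raw rows, and those partials must be added up rather than counted again. A minimal, stand-alone Scala illustration of that arithmetic, with made-up per-row-group counts (nothing here is taken from the patch):

// Hypothetical row counts read from three Parquet row-group footers.
val partialCounts = Seq(120L, 87L, 43L)
// COUNT(*) over the whole table is the sum of the partial counts, not their count.
val countStar = partialCounts.sum  // 250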
+ var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case _: aggregate.Max => aggregate.Max(output(i - 1)) + case _: aggregate.Min => aggregate.Min(output(i - 1)) + case _: aggregate.Sum => aggregate.Sum(output(i - 1)) + case _: aggregate.Count => aggregate.Sum(output(i - 1)) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction, filter = None) + } + } else { + aggNode + } + case _ => aggNode + } + + case _ => aggNode + } + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + // column pruning val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) + .normalizeExprs(project, sHolder.output) .asInstanceOf[Seq[NamedExpression]] val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) + sHolder.builder, sHolder.relation, normalizedProjects, filters) + logInfo( s""" - |Pushing operators to ${relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} - |Post-Scan Filters: ${postScanFilters.mkString(",")} |Output: ${output.mkString(", ")} """.stripMargin) val wrappedScan = scan match { case v1: V1Scan => val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) - V1ScanWrapper(v1, translated, pushedFilters, - Aggregation(Seq.empty[Seq[AggregateFunc]], Seq.empty[String])) - + val pushedFilters = sHolder.builder match { + case f: SupportsPushDownFilters => + f.pushedFilters() + } + V1ScanWrapper(v1, translated, pushedFilters, Aggregation.empty) case _ => scan } - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) val projectionOverSchema = ProjectionOverSchema(output.toStructType) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } - val filterCondition = postScanFilters.reduceLeftOption(And) + val filterCondition = filters.reduceLeftOption(And) val newFilterCondition = filterCondition.map(projectionFunc) val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) @@ -199,14 +227,60 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { withProjection } + + private def getAggregateExpression( + resultExpressions: Seq[NamedExpression], + project: Seq[NamedExpression], + sHolder: ScanBuilderHolder): Seq[AggregateExpression] = { + val aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => + replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] + } + } + DataSourceStrategy.normalizeExprs(aggregates, sHolder.relation.output) + .asInstanceOf[Seq[AggregateExpression]] + } + + private def collectFilters(plan: LogicalPlan): + Option[(Seq[Expression], LogicalPlan, AttributeMap[Expression])] = { + plan match { + case Filter(condition, child) => + collectFilters(child) match { + case Some((filters, other, aliases)) => + // Follow CombineFilters and only keep going if 1) the collected Filters + // and this filter are all deterministic or 2) if this filter is the first + // collected filter and doesn't have common non-deterministic expressions + // with lower Project. 
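A small, self-contained Scala sketch of the rule described in the comment above; the session setup and data are purely illustrative and not part of the patch. Two deterministic predicates can be collected and combined, while a non-deterministic predicate such as rand() stops the collection so that its sampling semantics are preserved:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, rand}

val spark = SparkSession.builder().master("local[*]").appName("combine-filters-sketch").getOrCreate()
// Both predicates are deterministic, so collecting and combining them is safe.
val combinable = spark.range(10).filter(col("id") > 1).filter(col("id") < 8)
// rand() is non-deterministic, so the filters are kept separate rather than merged.
val keptSeparate = spark.range(10).filter(rand() < 0.5).filter(col("id") < 8)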
+ val substitutedCondition = substitute(aliases)(condition) + val canCombineFilters = (filters.nonEmpty && filters.forall(_.deterministic) && + substitutedCondition.deterministic) || filters.isEmpty + if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) { + Some((filters ++ splitConjunctivePredicates(substitutedCondition), + other, aliases)) + } else { + None + } + case None => None + } + + case other => + Some((Nil, other, AttributeMap(Seq()))) + } + } } +case class ScanBuilderHolder( + output: Seq[AttributeReference], + relation: DataSourceV2Relation, + builder: ScanBuilder) extends LeafNode + // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by // the physical v1 scan node. case class V1ScanWrapper( v1Scan: V1Scan, translatedFilters: Seq[sources.Filter], handledFilters: Seq[sources.Filter], - pushedAggregates: sources.Aggregation) extends Scan { + pushedAggregates: Aggregation) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index caf259b69b377..e296027bac5a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -34,13 +34,14 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.{Aggregation, Filter} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -145,9 +146,7 @@ case class ParquetPartitionReaderFactory( override def get(): InternalRow = { count += 1 val footer = getFooter(file) - val (parquetTypes, values) = - ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) - ParquetUtils.aggResultToSparkInternalRows(footer, parquetTypes, values, aggSchema, + ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, aggregation, aggSchema, datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz(isCreatedByParquetMr(file))) } @@ -181,9 +180,7 @@ case class ParquetPartitionReaderFactory( override def get(): ColumnarBatch = { count += 1 val footer = getFooter(file) - val (parquetTypes, values) = - ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) - ParquetUtils.aggResultToSparkColumnarBatch(footer, parquetTypes, values, aggSchema, + ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, aggregation, aggSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead, int96RebaseModeInRead, 
convertTz(isCreatedByParquetMr(file))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 2b0bfb831f349..9d4149cd8da9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,12 +24,13 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{Aggregation, Filter} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index ea598f5e942d9..7f12d5dbd73bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, Max, Min} import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.{Aggregation, Count, Filter, Max, Min} -import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -76,15 +77,27 @@ case class ParquetScanBuilder( override def pushAggregation(aggregation: Aggregation): Unit = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || aggregation.groupByColumns.nonEmpty) { - Aggregation.empty return } - aggregation.aggregateExpressions.foreach { _ match { - // parquet's statistics doesn't have distinct count info - case Seq(Max(_, _)) | Seq(Min(_, _)) | Seq(Count(_, _, false)) => - case _ => Aggregation.empty - } + aggregation.aggregateExpressions.foreach { + case Max(col, _) => + dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) + .dataType match { + // not push down nested column + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return + case _ => + 
} + case Min(col, _) => + dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) + .dataType match { + // not push down nested column + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return + case _ => + } + // not push down distinct count + case Count(_, _, false) => + case _ => return } this.pushedAggregations = aggregation } @@ -97,15 +110,15 @@ case class ParquetScanBuilder( override def getPushDownAggSchema: StructType = { var schema = new StructType() - pushedAggregations.aggregateExpressions.map { - case Seq(Max(col, _)) => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) + pushedAggregations.aggregateExpressions.foreach { + case Max(col, _) => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) schema = schema.add(field.copy("max(" + field.name + ")")) - case Seq(Min(col, _)) => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col)) + case Min(col, _) => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) schema = schema.add(field.copy("min(" + field.name + ")")) - case Seq(Count(col, _, _)) => - if (col.equals("1")) { + case Count(col, _, _) => + if (col.fieldNames.head.equals("1")) { schema = schema.add(new StructField("count(*)", LongType)) } else { schema = schema.add(new StructField("count(" + col + ")", LongType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 19f47f18d4ff5..2ac1a77fefb45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{And, Expression, IsNull, LessThan} +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitionSpec} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan @@ -31,7 +32,7 @@ import org.apache.spark.sql.execution.datasources.v2.json.JsonScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.datasources.v2.text.TextScan -import org.apache.spark.sql.sources.{Aggregation, Filter} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 277eeda5f0ce0..52a3c26c698da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -904,6 +904,23 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } + test("test aggregate push down - nested data ") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + 
withParquetTable(data, "t") { + val count = sql("SELECT Count(_1) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Count(_1,LongType,false)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + test("test aggregate push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) @@ -925,16 +942,16 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS selectAgg3.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [List(Min(_3,IntegerType)), " + - "List(Min(_3,IntegerType)), " + - "List(Max(_3,IntegerType)), " + - "List(Min(_1,IntegerType)), " + - "List(Max(_1,IntegerType)), " + - "List(Max(_1,IntegerType)), " + - "List(Count(1,LongType,false)), " + - "List(Count(_1,LongType,false)), " + - "List(Count(_2,LongType,false)), " + - "List(Count(_3,LongType,false))]" + "PushedAggregation: [Min(_3,IntegerType), " + + "Min(_3,IntegerType), " + + "Max(_3,IntegerType), " + + "Min(_1,IntegerType), " + + "Max(_1,IntegerType), " + + "Max(_1,IntegerType), " + + "Count(`1`,LongType,false), " + + "Count(_1,LongType,false), " + + "Count(_2,LongType,false), " + + "Count(_3,LongType,false)]" checkKeywordsExistsInExplain(selectAgg3, expected_plan_fragment) } @@ -1030,17 +1047,17 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS testMin.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [List(Min(StringCol,StringType)), " + - "List(Min(BooleanCol,BooleanType)), " + - "List(Min(ByteCol,ByteType)), " + - "List(Min(BinaryCol,BinaryType)), " + - "List(Min(ShortCol,ShortType)), " + - "List(Min(IntegerCol,IntegerType)), " + - "List(Min(LongCol,LongType)), " + - "List(Min(FloatCol,FloatType)), " + - "List(Min(DoubleCol,DoubleType)), " + - "List(Min(DecimalCol,DecimalType(25,5))), " + - "List(Min(DateCol,DateType))]" + "PushedAggregation: [Min(StringCol,StringType), " + + "Min(BooleanCol,BooleanType), " + + "Min(ByteCol,ByteType), " + + "Min(BinaryCol,BinaryType), " + + "Min(ShortCol,ShortType), " + + "Min(IntegerCol,IntegerType), " + + "Min(LongCol,LongType), " + + "Min(FloatCol,FloatType), " + + "Min(DoubleCol,DoubleType), " + + "Min(DecimalCol,DecimalType(25,5)), " + + "Min(DateCol,DateType)]" checkKeywordsExistsInExplain(testMin, expected_plan_fragment) } @@ -1055,17 +1072,17 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS testMax.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [List(Max(StringCol,StringType)), " + - "List(Max(BooleanCol,BooleanType)), " + - "List(Max(ByteCol,ByteType)), " + - "List(Max(BinaryCol,BinaryType)), " + - "List(Max(ShortCol,ShortType)), " + - "List(Max(IntegerCol,IntegerType)), " + - "List(Max(LongCol,LongType)), " + - "List(Max(FloatCol,FloatType)), " + - "List(Max(DoubleCol,DoubleType)), " + - "List(Max(DecimalCol,DecimalType(25,5))), " + - "List(Max(DateCol,DateType))]" + "PushedAggregation: [Max(StringCol,StringType), " + + "Max(BooleanCol,BooleanType), " + + "Max(ByteCol,ByteType), " + + "Max(BinaryCol,BinaryType), " + + "Max(ShortCol,ShortType), " + + "Max(IntegerCol,IntegerType), " + + "Max(LongCol,LongType), " + + "Max(FloatCol,FloatType), " + + 
"Max(DoubleCol,DoubleType), " + + "Max(DecimalCol,DecimalType(25,5)), " + + "Max(DateCol,DateType)]" checkKeywordsExistsInExplain(testMax, expected_plan_fragment) } @@ -1081,19 +1098,19 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS testCount.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [List(Count(1,LongType,false)), " + - "List(Count(StringCol,LongType,false)), " + - "List(Count(BooleanCol,LongType,false)), " + - "List(Count(ByteCol,LongType,false)), " + - "List(Count(BinaryCol,LongType,false)), " + - "List(Count(ShortCol,LongType,false)), " + - "List(Count(IntegerCol,LongType,false)), " + - "List(Count(LongCol,LongType,false)), " + - "List(Count(FloatCol,LongType,false)), " + - "List(Count(DoubleCol,LongType,false)), " + - "List(Count(DecimalCol,LongType,false)), " + - "List(Count(DateCol,LongType,false)), " + - "List(Count(TimestampCol,LongType,false))]" + "PushedAggregation: [Count(`1`,LongType,false), " + + "Count(StringCol,LongType,false), " + + "Count(BooleanCol,LongType,false), " + + "Count(ByteCol,LongType,false), " + + "Count(BinaryCol,LongType,false), " + + "Count(ShortCol,LongType,false), " + + "Count(IntegerCol,LongType,false), " + + "Count(LongCol,LongType,false), " + + "Count(FloatCol,LongType,false), " + + "Count(DoubleCol,LongType,false), " + + "Count(DecimalCol,LongType,false), " + + "Count(DateCol,LongType,false), " + + "Count(TimestampCol,LongType,false)]" checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } From d0a61f8a254a25a16d41a79732c6297f2607d77d Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 28 May 2021 08:28:40 -0700 Subject: [PATCH 19/30] add default case --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ba12c390f70a2..40b6be7c4201d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -200,6 +200,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper val pushedFilters = sHolder.builder match { case f: SupportsPushDownFilters => f.pushedFilters() + case _ => Array.empty[sources.Filter] } V1ScanWrapper(v1, translated, pushedFilters, Aggregation.empty) case _ => scan From 91f013c2fa03f2e8904ad3a98f61b8e890722c0a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 9 Jun 2021 12:09:10 -0700 Subject: [PATCH 20/30] address comments --- .../read/SupportsPushDownAggregates.java | 8 +- .../datasources/v2/PushDownUtils.scala | 8 +- .../v2/V2ScanRelationPushDown.scala | 158 ++++++++---------- .../v2/parquet/ParquetScanBuilder.scala | 13 +- .../parquet/ParquetQuerySuite.scala | 24 ++- 5 files changed, 95 insertions(+), 116 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 1923826b9adaa..f0f090afa6123 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ 
-35,13 +35,7 @@ public interface SupportsPushDownAggregates extends ScanBuilder { * The Aggregation can be pushed down only if all the Aggregate Functions can * be pushed down. */ - void pushAggregation(Aggregation aggregation); - - /** - * Returns the aggregation that are pushed to the data source via - * {@link #pushAggregation(Aggregation aggregation)}. - */ - Aggregation pushedAggregation(); + boolean pushAggregation(Aggregation aggregation); /** * Returns the schema of the pushed down aggregates diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 4ef6a984c0d01..2bad04eb0d240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -93,11 +93,11 @@ object PushDownUtils extends PredicateHelper { .translateAggregate(_, PushableColumn(false))) val translatedGroupBys = groupBy.map(columnAsString) - if (translatedAggregates.exists(_.isEmpty) || translatedGroupBys.exists(_.isEmpty)) { - Aggregation.empty + val agg = Aggregation(translatedAggregates.flatten, translatedGroupBys.flatten) + if (r.pushAggregation(agg)) { + agg } else { - r.pushAggregation(Aggregation(translatedAggregates.flatten, translatedGroupBys.flatten)) - r.pushedAggregation + Aggregation.empty } case _ => Aggregation.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 40b6be7c4201d..0ce7d817153bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -58,129 +58,108 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( sHolder.builder, normalizedFiltersWithoutSubquery) val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery - val output = sHolder.output logInfo( s""" - |Pushing operators to ${sHolder.relation.asInstanceOf[DataSourceV2Relation].name} + |Pushing operators to ${sHolder.relation.name} |Pushed Filters: ${pushedFilters.mkString(", ")} |Post-Scan Filters: ${postScanFilters.mkString(",")} """.stripMargin) - val projectionOverSchema = ProjectionOverSchema(output.toStructType) - val projectionFunc = (expr: Expression) => expr transformDown { - case projectionOverSchema(newExpr) => newExpr - } - val filterCondition = postScanFilters.reduceLeftOption(And) - val newFilterCondition = filterCondition.map(projectionFunc) - newFilterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) + filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) } def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + case ScanOperation(project, _, sHolder: ScanBuilderHolder) => sHolder.builder match { case r: SupportsPushDownAggregates => - if (filters.isEmpty || r.supportsPushDownAggregateWithFilter()) { + if 
(sHolder.builder.asInstanceOf[SupportsPushDownFilters].pushedFilters().length <= 0 + || r.supportsPushDownAggregateWithFilter()) { if (r.supportsGlobalAggregatePushDownOnly() && groupingExpressions.nonEmpty) { aggNode // return original plan node } else { val aggregates = getAggregateExpression(resultExpressions, project, sHolder) val pushedAggregates = PushDownUtils .pushAggregates(sHolder.builder, aggregates, groupingExpressions) - logInfo( - s""" - |Pushed Aggregates: ${pushedAggregates.aggregateExpressions.mkString(", ")} - """.stripMargin) - aggNode - } - } else { - aggNode - } - case _ => aggNode - } - case _ => aggNode - } - } - - def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { - case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => - child match { - case ScanOperation(project, _, sHolder: ScanBuilderHolder) => - sHolder.builder match { - case _: SupportsPushDownAggregates => - if (sHolder.builder.asInstanceOf[SupportsPushDownAggregates] - .pushedAggregation().aggregateExpressions.nonEmpty) { - val aggregates = getAggregateExpression(resultExpressions, project, sHolder) - // use the aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // Use min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - val output = aggregates.map { - case agg: AggregateExpression => - AttributeReference(toPrettySQL(agg), agg.dataType)() - } + if (pushedAggregates.aggregateExpressions.isEmpty) { + aggNode // return original plan node + } else { + // use the aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // Use min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + val output = aggregates.map { + case agg: AggregateExpression => + AttributeReference(toPrettySQL(agg), agg.dataType)() + } - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. Since PushDownUtils.pruneColumns is not called, - // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is - // not used anyways. The schema for aggregate columns will be built in Scan. - val scan = sHolder.builder.build() + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. 
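The guard above only attempts aggregate push-down when no filters were pushed to the data source (or when the source explicitly supports aggregates with filters), because Parquet footer statistics describe every row in a row group and cannot reflect a pushed predicate. A rough, self-contained Scala sketch of the difference; the path, data, and view name are illustrative only, and the second plan is expected to show PushedAggregation: [] as in the test suite changes later in this series:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

spark.conf.set(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key, "true")
// Illustrative data written as Parquet and registered as a view named "t".
Seq((-2, 2), (3, 4), (6, 2)).toDF("_1", "_2")
  .write.mode("overwrite").parquet("/tmp/agg_pushdown_demo")
spark.read.parquet("/tmp/agg_pushdown_demo").createOrReplaceTempView("t")

// Global aggregate with no filter: eligible, answered from footer statistics.
spark.sql("SELECT MIN(_1), MAX(_1), COUNT(*) FROM t").explain()
// A pushed filter blocks aggregate push-down; the aggregate stays in Spark.
spark.sql("SELECT MIN(_1) FROM t WHERE _2 > 0").explain()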
+ val scan = sHolder.builder.build() - logInfo( - s""" - |Output: ${output.mkString(", ")} - """.stripMargin) + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Aggregate Functions: ${pushedAggregates.aggregateExpressions.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) - val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) + val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // The original logical plan is - // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9] parquet ... - // - // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] - // we have the following - // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] - // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - var i = 0 - plan.transformExpressions { - case agg: AggregateExpression => - i += 1 - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case _: aggregate.Max => aggregate.Max(output(i - 1)) - case _: aggregate.Min => aggregate.Min(output(i - 1)) - case _: aggregate.Sum => aggregate.Sum(output(i - 1)) - case _: aggregate.Count => aggregate.Sum(output(i - 1)) - case _ => agg.aggregateFunction - } - agg.copy(aggregateFunction = aggFunction, filter = None) + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // The original logical plan is + // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9] parquet ... + // + // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // we have the following + // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] + // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... 
+ var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case _: aggregate.Max => aggregate.Max(output(i - 1)) + case _: aggregate.Min => aggregate.Min(output(i - 1)) + case _: aggregate.Sum => aggregate.Sum(output(i - 1)) + case _: aggregate.Count => aggregate.Sum(output(i - 1)) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction, filter = None) + } + } } } else { aggNode } case _ => aggNode } - case _ => aggNode } + } + + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning val normalizedProjects = DataSourceStrategy @@ -225,7 +204,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper } else { withFilter } - withProjection } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 7f12d5dbd73bf..b965693962b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -74,10 +74,10 @@ case class ParquetScanBuilder( private var pushedAggregations = Aggregation.empty - override def pushAggregation(aggregation: Aggregation): Unit = { + override def pushAggregation(aggregation: Aggregation): Boolean = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || aggregation.groupByColumns.nonEmpty) { - return + return false } aggregation.aggregateExpressions.foreach { @@ -85,25 +85,24 @@ case class ParquetScanBuilder( dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) .dataType match { // not push down nested column - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return false case _ => } case Min(col, _) => dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) .dataType match { // not push down nested column - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return false case _ => } // not push down distinct count case Count(_, _, false) => - case _ => return + case _ => return false } this.pushedAggregations = aggregation + true } - override def pushedAggregation(): Aggregation = pushedAggregations - override def supportsGlobalAggregatePushDownOnly(): Boolean = true override def supportsPushDownAggregateWithFilter(): Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 52a3c26c698da..af1a47abc4bd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -928,18 +928,26 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS withSQLConf( SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg1 = sql("SELECT min(_3) FROM t WHERE _1 > 0") + selectAgg1.queryExecution.optimizedPlan.collect { + case _: 
DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IsNotNull(_1), GreaterThan(_1,0)], PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg1, expected_plan_fragment) + } + // This is not pushed down since aggregates have arithmetic operation - val selectAgg1 = sql("SELECT min(_3 + _1), max(_3 + _1) FROM t") - checkAnswer(selectAgg1, Seq(Row(0, 19))) + val selectAgg2 = sql("SELECT min(_3 + _1), max(_3 + _1) FROM t") + checkAnswer(selectAgg2, Seq(Row(0, 19))) // sum is not pushed down - val selectAgg2 = sql("SELECT sum(_3) FROM t") - checkAnswer(selectAgg2, Seq(Row(40))) + val selectAgg3 = sql("SELECT sum(_3) FROM t") + checkAnswer(selectAgg3, Seq(Row(40))) - val selectAgg3 = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + val selectAgg4 = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + " count(*), count(_1), count(_2), count(_3) FROM t") - selectAgg3.queryExecution.optimizedPlan.collect { + selectAgg4.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregation: [Min(_3,IntegerType), " + @@ -952,10 +960,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS "Count(_1,LongType,false), " + "Count(_2,LongType,false), " + "Count(_3,LongType,false)]" - checkKeywordsExistsInExplain(selectAgg3, expected_plan_fragment) + checkKeywordsExistsInExplain(selectAgg4, expected_plan_fragment) } - checkAnswer(selectAgg3, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + checkAnswer(selectAgg4, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) } } spark.sessionState.catalog.dropTable( From 4fbe6663889ed78cb5c211c2f3944b547af4fc02 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 9 Jun 2021 13:22:31 -0700 Subject: [PATCH 21/30] fix lint error --- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 0ce7d817153bf..cad09cd1ac8b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -111,7 +111,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper logInfo( s""" |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: ${pushedAggregates.aggregateExpressions.mkString(", ")} + |Pushed Aggregate Functions: + | ${pushedAggregates.aggregateExpressions.mkString(", ")} |Output: ${output.mkString(", ")} """.stripMargin) From a5833ef7f551980ef48229932d9427a9e00af444 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 15 Jun 2021 19:42:26 -0700 Subject: [PATCH 22/30] not push down agg if it has timestamp (INT96 is sort order is undefined) --- .../datasources/parquet/ParquetUtils.scala | 52 ++---------------- .../ParquetPartitionReaderFactory.scala | 5 +- .../v2/parquet/ParquetScanBuilder.scala | 9 ++-- .../parquet/ParquetQuerySuite.scala | 54 +++++++++++++++---- 4 files changed, 55 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 0edf38eb39675..f4ceccbda9ce1 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} -import java.time.{ZoneId, ZoneOffset} import java.util import scala.collection.mutable.ArrayBuilder @@ -33,11 +32,10 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.{Aggregation, Count, Max, Min} import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} -import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.types.UTF8String @@ -160,9 +158,7 @@ object ParquetUtils { dataSchema: StructType, aggregation: Aggregation, aggSchema: StructType, - datetimeRebaseModeInRead: String, - int96RebaseModeInRead: String, - convertTz: Option[ZoneId]): InternalRow = { + datetimeRebaseModeInRead: String): InternalRow = { val (parquetTypes, values) = ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) val mutableRow = new SpecificInternalRow(aggSchema.fields.map(x => x.dataType)) @@ -170,9 +166,6 @@ object ParquetUtils { val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) - val int96RebaseMode = DataSourceUtils.int96RebaseMode( - footerFileMetaData.getKeyValueMetaData.get, - int96RebaseModeInRead) parquetTypes.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.INT32, i) => aggSchema.fields(i).dataType match { @@ -200,22 +193,6 @@ object ParquetUtils { mutableRow.setDecimal(i, decimal, d.precision) case _ => throw new SparkException("Unexpected type for INT64") } - case (PrimitiveType.PrimitiveTypeName.INT96, i) => - aggSchema.fields(i).dataType match { - case LongType => - mutableRow.setLong(i, values(i).asInstanceOf[Long]) - case TimestampType => - val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - int96RebaseMode, "Parquet INT96") - val julianMicros = - ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) - val gregorianMicros = int96RebaseFunc(julianMicros) - val adjTime = - convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) - .getOrElse(gregorianMicros) - mutableRow.setLong(i, adjTime) - case _ => throw new SparkException("Unexpected type for INT96") - } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => mutableRow.setFloat(i, values(i).asInstanceOf[Float]) case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => @@ -265,9 +242,7 @@ object ParquetUtils { aggregation: Aggregation, aggSchema: StructType, offHeap: Boolean, - datetimeRebaseModeInRead: String, - int96RebaseModeInRead: String, - convertTz: Option[ZoneId]): ColumnarBatch = { + datetimeRebaseModeInRead: String): ColumnarBatch = { val (parquetTypes, values) = 
ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) val capacity = 4 * 1024 @@ -275,9 +250,6 @@ object ParquetUtils { val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) - val int96RebaseMode = DataSourceUtils.int96RebaseMode( - footerFileMetaData.getKeyValueMetaData.get, - int96RebaseModeInRead) val columnVectors = if (offHeap) { OffHeapColumnVector.allocateColumns(capacity, aggSchema) } else { @@ -301,22 +273,6 @@ object ParquetUtils { } case (PrimitiveType.PrimitiveTypeName.INT64, i) => columnVectors(i).appendLong(values(i).asInstanceOf[Long]) - case (PrimitiveType.PrimitiveTypeName.INT96, i) => - aggSchema.fields(i).dataType match { - case LongType => - columnVectors(i).appendLong(values(i).asInstanceOf[Long]) - case TimestampType => - val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - int96RebaseMode, "Parquet INT96") - val julianMicros = - ParquetRowConverter.binaryToSQLTimestamp(values(i).asInstanceOf[Binary]) - val gregorianMicros = int96RebaseFunc(julianMicros) - val adjTime = - convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) - .getOrElse(gregorianMicros) - columnVectors(i).appendLong(adjTime) - case _ => throw new SparkException("Unexpected type for INT96") - } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => @@ -391,7 +347,7 @@ object ParquetUtils { } if (isCount) { valuesBuilder += rowCount - typesBuilder += PrimitiveType.PrimitiveTypeName.INT96 + typesBuilder += PrimitiveType.PrimitiveTypeName.INT64 } else { valuesBuilder += value typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index e296027bac5a5..96b98e9c12789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -147,7 +147,7 @@ case class ParquetPartitionReaderFactory( count += 1 val footer = getFooter(file) ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, aggregation, aggSchema, - datetimeRebaseModeInRead, int96RebaseModeInRead, convertTz(isCreatedByParquetMr(file))) + datetimeRebaseModeInRead) } override def close(): Unit = return @@ -181,8 +181,7 @@ case class ParquetPartitionReaderFactory( count += 1 val footer = getFooter(file) ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, aggregation, aggSchema, - enableOffHeapColumnVector, datetimeRebaseModeInRead, int96RebaseModeInRead, - convertTz(isCreatedByParquetMr(file))) + enableOffHeapColumnVector, datetimeRebaseModeInRead) } override def close(): Unit = return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index b965693962b52..562dee8a5ad34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala 
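The hunk below adds TimestampType to the types excluded from Min/Max push-down, because INT96 timestamps have an undefined sort order and Parquet does not return reliable statistics for them; nested columns are skipped as well, since a struct, array, or map column does not map to a single footer statistic. The Max and Min branches repeat the same type check, so a helper along these lines could state the rule once (a hypothetical refactor sketch, not code from the patch):

import org.apache.spark.sql.types._

object MinMaxPushDownEligibility {
  // Nested types carry no single usable footer statistic; INT96 timestamps have an
  // undefined sort order, so their min/max statistics are unreliable.
  def supportsMinMaxPushDown(dataType: DataType): Boolean = dataType match {
    case _: StructType | _: ArrayType | _: MapType | TimestampType => false
    case _ => true
  }
}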
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -84,15 +84,16 @@ case class ParquetScanBuilder( case Max(col, _) => dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) .dataType match { - // not push down nested column - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return false + // not push down nested column and Timestamp (INT96 sort order is undefined, parquet + // doesn't return statistics for INT96) + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false case _ => } case Min(col, _) => dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) .dataType match { // not push down nested column - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) => return false + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false case _ => } // not push down distinct count diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index af1a47abc4bd7..70de0f2499cf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -1048,11 +1048,28 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS val df2 = spark.read.parquet(file.getCanonicalPath) df2.createOrReplaceTempView("test") - val testMin = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMinWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment) + } + + checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + + val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") - testMin.queryExecution.optimizedPlan.collect { + testMinWithOutTS.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregation: [Min(StringCol,StringType), " + @@ -1066,18 +1083,35 @@ abstract class ParquetQuerySuite extends QueryTest with 
ParquetTest with SharedS "Min(DoubleCol,DoubleType), " + "Min(DecimalCol,DecimalType(25,5)), " + "Min(DateCol,DateType)]" - checkKeywordsExistsInExplain(testMin, expected_plan_fragment) + checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) } - checkAnswer(testMin, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, ("2004-06-19").date))) - val testMax = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMaxWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + + val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") - testMax.queryExecution.optimizedPlan.collect { + testMaxWithoutTS.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregation: [Max(StringCol,StringType), " + @@ -1091,12 +1125,12 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS "Max(DoubleCol,DoubleType), " + "Max(DecimalCol,DecimalType(25,5)), " + "Max(DateCol,DateType)]" - checkKeywordsExistsInExplain(testMax, expected_plan_fragment) + checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) } - checkAnswer(testMax, Seq(Row("test string", true, 16.toByte, "Spark SQL".getBytes, - 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, 12345.678, - ("2021-01-01").date))) + checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date))) val testCount = sql("SELECT count(*), count(StringCol), count(BooleanCol)," + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + From a60c9a1c90a7f2093625e7fa60fd708535e20c46 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 18 Jun 2021 11:10:57 -0700 Subject: [PATCH 23/30] address comments --- .../read/SupportsPushDownAggregates.java | 5 -- .../connector/expressions/aggregates.scala | 12 +-- .../datasources/DataSourceStrategy.scala | 12 +-- .../datasources/parquet/ParquetUtils.scala | 54 ++++++++----- .../v2/V2ScanRelationPushDown.scala | 5 +- .../v2/parquet/ParquetScanBuilder.scala | 81 ++++++++++++------- .../parquet/ParquetQuerySuite.scala | 6 +- 7 files changed, 98 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 
f0f090afa6123..067e644c8ac9a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -46,9 +46,4 @@ public interface SupportsPushDownAggregates extends ScanBuilder { * Indicate if the data source only supports global aggregated push down */ boolean supportsGlobalAggregatePushDownOnly(); - - /** - * Indicate if the data source supports push down aggregates along with filters - */ - boolean supportsPushDownAggregateWithFilter(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala index 0e3729bf3648f..a4f1440df58f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.types.DataType // e.g. SELECT COUNT(EmployeeID), Max(salary), deptID FROM dept GROUP BY deptID // aggregateExpressions are (COUNT(EmployeeID), Max(salary)), groupByColumns are (deptID) case class Aggregation(aggregateExpressions: Seq[AggregateFunc], - groupByColumns: Seq[FieldReference]) + groupByColumns: Seq[Expression]) abstract class AggregateFunc -case class Min(column: FieldReference, dataType: DataType) extends AggregateFunc -case class Max(column: FieldReference, dataType: DataType) extends AggregateFunc -case class Sum(column: FieldReference, dataType: DataType, isDistinct: Boolean) +case class Min(column: Expression, dataType: DataType) extends AggregateFunc +case class Max(column: Expression, dataType: DataType) extends AggregateFunc +case class Sum(column: Expression, dataType: DataType, isDistinct: Boolean) extends AggregateFunc -case class Count(column: FieldReference, dataType: DataType, isDistinct: Boolean) +case class Count(column: Expression, dataType: DataType, isDistinct: Boolean) extends AggregateFunc object Aggregation { // Returns an empty Aggregate - def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[FieldReference]) + def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[Expression]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8e347436fa977..53980b25b6578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, FieldReference, Max, Min} +import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, FieldReference, LiteralValue, Max, Min} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -684,12 +684,14 @@ object DataSourceStrategy case max @ aggregate.Max(pushableColumn(name)) => Some(Max(FieldReference(Seq(name)), 
max.dataType)) case count: aggregate.Count => - val columnName = count.children.head match { + count.children.head match { // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table - case Literal(_, _) => "1" - case pushableColumn(name) => name + case Literal(_, _) => + Some(Count(LiteralValue(1, LongType), LongType, aggregates.isDistinct)) + case pushableColumn(name) => + Some(Count(FieldReference(Seq(name)), LongType, aggregates.isDistinct)) + case _ => None } - Some(Count(FieldReference(Seq(columnName)), count.dataType, aggregates.isDistinct)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index f4ceccbda9ce1..426e53a9eee93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.connector.expressions.{Aggregation, Count, Max, Min} +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType} @@ -318,30 +318,40 @@ object ParquetUtils { blocks.forEach { block => val blockMetaData = block.getColumns() aggregation.aggregateExpressions(i) match { - case Max(col, _) => - index = dataSchema.fieldNames.toList.indexOf(col.fieldNames.head) - val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) - if (currentMax != None && - (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { - value = currentMax - } - - case Min(col, _) => - index = dataSchema.fieldNames.toList.indexOf(col.fieldNames.head) - val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) - if (currentMin != None && - (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { - value = currentMin - } - + case Max(col, _) => col match { + case ref: FieldReference => + index = dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head) + val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) + if (currentMax != None && + (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { + value = currentMax + } + case _ => + throw new SparkException(s"Expression $col is not currently supported.") + } + case Min(col, _) => col match { + case ref: FieldReference => + index = dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head) + val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) + if (currentMin != None && + (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { + value = currentMin + } + case _ => + throw new SparkException("Expression $col is not currently supported.") + } case Count(col, _, _) => - index = dataSchema.fieldNames.toList.indexOf(col.fieldNames.head) rowCount += block.getRowCount - if (!col.fieldNames.head.equals("1")) { // "1" is for count(*) - 
rowCount -= getNumNulls(blockMetaData, index) - } isCount = true - + col match { + case ref: FieldReference => + index = dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head) + // Count(*) includes the null values, but Count (colName) doesn't. + rowCount -= getNumNulls(blockMetaData, index) + case LiteralValue(1, _) => // Count(*) + case _ => + throw new SparkException("Expression $col is not currently supported.") + } case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index cad09cd1ac8b5..fa49220f02273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -74,11 +74,10 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { - case ScanOperation(project, _, sHolder: ScanBuilderHolder) => + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => sHolder.builder match { case r: SupportsPushDownAggregates => - if (sHolder.builder.asInstanceOf[SupportsPushDownFilters].pushedFilters().length <= 0 - || r.supportsPushDownAggregateWithFilter()) { + if (filters.length == 0) { // can't push down aggregate if postScanFilters exist if (r.supportsGlobalAggregatePushDownOnly() && groupingExpressions.nonEmpty) { aggNode // return original plan node } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 562dee8a5ad34..dce1bc869a880 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.{Aggregation, Count, Max, Min} +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} @@ -76,26 +77,35 @@ case class ParquetScanBuilder( override def pushAggregation(aggregation: Aggregation): Boolean = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByColumns.nonEmpty) { + aggregation.groupByColumns.nonEmpty || pushedParquetFilters.length > 0) { return false } aggregation.aggregateExpressions.foreach { - case Max(col, _) => - dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) - .dataType match { - // not push down nested column and Timestamp (INT96 sort order is undefined, parquet - // doesn't return statistics for INT96) - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | 
TimestampType => return false - case _ => - } - case Min(col, _) => - dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) - .dataType match { - // not push down nested column - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false - case _ => - } + case Max(col, _) => col match { + case ref: FieldReference => + dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + .dataType match { + // not push down nested column and Timestamp (INT96 sort order is undefined, parquet + // doesn't return statistics for INT96) + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case _ => + } + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case Min(col, _) => col match { + case ref: FieldReference => + dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + .dataType match { + // not push down nested column and Timestamp (INT96 sort order is undefined, parquet + // doesn't return statistics for INT96) + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case _ => + } + case _ => + throw new SparkException("Expression $col is not currently supported.") + } // not push down distinct count case Count(_, _, false) => case _ => return false @@ -106,24 +116,33 @@ case class ParquetScanBuilder( override def supportsGlobalAggregatePushDownOnly(): Boolean = true - override def supportsPushDownAggregateWithFilter(): Boolean = false - override def getPushDownAggSchema: StructType = { var schema = new StructType() pushedAggregations.aggregateExpressions.foreach { - case Max(col, _) => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) - schema = schema.add(field.copy("max(" + field.name + ")")) - case Min(col, _) => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(col.fieldNames.head)) - schema = schema.add(field.copy("min(" + field.name + ")")) - case Count(col, _, _) => - if (col.fieldNames.head.equals("1")) { - schema = schema.add(new StructField("count(*)", LongType)) - } else { - schema = schema.add(new StructField("count(" + col + ")", LongType)) - } + case Max(col, _) => col match { + case ref: FieldReference => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + schema = schema.add(field.copy("max(" + field.name + ")")) + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case Min(col, _) => col match { + case ref: FieldReference => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + schema = schema.add(field.copy("min(" + field.name + ")")) + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case Count(col, _, _) => col match { + case _: FieldReference => + schema = schema.add(StructField("count(" + col + ")", LongType)) + case LiteralValue(1, _) => + schema = schema.add(StructField("count(*)", LongType)) + case _ => + throw new SparkException("Expression $col is not currently supported.") + } case _ => + throw new SparkException("Pushed down aggregate is not supported.") } schema } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 70de0f2499cf7..ca38443b6bf73 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -956,7 +956,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS "Min(_1,IntegerType), " + "Max(_1,IntegerType), " + "Max(_1,IntegerType), " + - "Count(`1`,LongType,false), " + + "Count(1,LongType,false), " + "Count(_1,LongType,false), " + "Count(_2,LongType,false), " + "Count(_3,LongType,false)]" @@ -1043,10 +1043,6 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", vectorizedReaderEnabledKey -> testVectorizedReader) { withTempPath { file => - val df1 = spark.createDataFrame(rdd, schema) - df1.write.parquet(file.getCanonicalPath) - val df2 = spark.read.parquet(file.getCanonicalPath) - df2.createOrReplaceTempView("test") val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + From 475c0f77f849f5c4b783affb6178b6c8367a67f8 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 23 Jun 2021 00:08:44 -0500 Subject: [PATCH 24/30] change SupportsPushDownAggregates interface --- .../read/SupportsPushDownAggregates.java | 41 ++++-- .../datasources/v2/PushDownUtils.scala | 2 +- .../v2/V2ScanRelationPushDown.scala | 127 +++++++++--------- .../v2/parquet/ParquetScanBuilder.scala | 21 +-- 4 files changed, 103 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 067e644c8ac9a..40ab1b0347463 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -23,7 +23,14 @@ /** * A mix-in interface for {@link ScanBuilder}. Data source can implement this interface to - * push down aggregates to the data source. + * push down aggregates. Depends on the data source implementation, the aggregates may not + * be able to push down, partially push down and have final aggregate at Spark, or completely + * push down. + * + * When pushing down operators, Spark pushes down filter to the data source first, then push down + * aggregates or apply column pruning. Depends on data source implementation, aggregates may or + * may not be able to be pushed down with filters. If pushed filters still need to be evaluated + * after scanning, aggregates can't be pushed down. * * @since 3.2.0 */ @@ -32,18 +39,28 @@ public interface SupportsPushDownAggregates extends ScanBuilder { /** * Pushes down Aggregation to datasource. - * The Aggregation can be pushed down only if all the Aggregate Functions can - * be pushed down. 
*/ - boolean pushAggregation(Aggregation aggregation); + AggregatePushDownResult pushAggregation(Aggregation aggregation); - /** - * Returns the schema of the pushed down aggregates - */ - StructType getPushDownAggSchema(); + class AggregatePushDownResult { - /** - * Indicate if the data source only supports global aggregated push down - */ - boolean supportsGlobalAggregatePushDownOnly(); + // 0: aggregates not pushed down + // 1: aggregates partially pushed down, need to final aggregate in Spark + // 2: aggregates completely pushed down, doesn't need to final aggregate in Spark + int pushedDownResult = 0; + StructType pushedDownAggSchema; + + public AggregatePushDownResult(int pushedDownResult, StructType pushedDownAggSchema) { + this.pushedDownResult = pushedDownResult; + this.pushedDownAggSchema = pushedDownAggSchema; + } + + public int getPushedDownResult() { + return pushedDownResult; + } + + public StructType getPushedDownAggSchema() { + return pushedDownAggSchema; + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 2bad04eb0d240..87d52ae8e557f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -94,7 +94,7 @@ object PushDownUtils extends PredicateHelper { val translatedGroupBys = groupBy.map(columnAsString) val agg = Aggregation(translatedAggregates.flatten, translatedGroupBys.flatten) - if (r.pushAggregation(agg)) { + if (r.pushAggregation(agg).getPushedDownResult > 0) { agg } else { Aggregation.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index fa49220f02273..a022bfd2292f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -72,84 +72,81 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed - case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => + case aggNode@Aggregate(groupingExpressions, resultExpressions, child) => child match { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => sHolder.builder match { case r: SupportsPushDownAggregates => - if (filters.length == 0) { // can't push down aggregate if postScanFilters exist - if (r.supportsGlobalAggregatePushDownOnly() && groupingExpressions.nonEmpty) { + if (filters.length == 0) { // can't push down aggregate if postScanFilters exist + val aggregates = getAggregateExpression(resultExpressions, project, sHolder) + val pushedAggregates = PushDownUtils + .pushAggregates(sHolder.builder, aggregates, groupingExpressions) + if (pushedAggregates.aggregateExpressions.isEmpty) { aggNode // return original plan node } else { - val aggregates = getAggregateExpression(resultExpressions, project, sHolder) - val pushedAggregates = PushDownUtils - .pushAggregates(sHolder.builder, aggregates, groupingExpressions) - if (pushedAggregates.aggregateExpressions.isEmpty) { - aggNode // return 
original plan node - } else { - // use the aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // Use min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - val output = aggregates.map { - case agg: AggregateExpression => - AttributeReference(toPrettySQL(agg), agg.dataType)() - } + // use the aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // Use min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + val output = aggregates.map { + case agg: AggregateExpression => + AttributeReference(toPrettySQL(agg), agg.dataType)() + } - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. Since PushDownUtils.pruneColumns is not called, - // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is - // not used anyways. The schema for aggregate columns will be built in Scan. - val scan = sHolder.builder.build() + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. + val scan = sHolder.builder.build() - logInfo( - s""" - |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: - | ${pushedAggregates.aggregateExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Aggregate Functions: + | ${pushedAggregates.aggregateExpressions.mkString(", ")} + |Output: ${output.mkString(", ")} """.stripMargin) - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) - val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) + val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // The original logical plan is - // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9] parquet ... - // - // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] - // we have the following - // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] - // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... 
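
To make the rewrite above concrete: each pushed Min/Max is re-aggregated with the same function over the scan's partial-result column, while a pushed Count must be re-aggregated with Sum, because every row group contributes its own partial count. A minimal standalone Scala sketch of that combine step (the Partial class, names, and values are illustrative only, not part of this patch):

object ReAggregationSketch {
  // One partial result per Parquet row group (or per split), as the pushed-down
  // scan would produce it: footer statistics supply min/max, getRowCount the count.
  final case class Partial(min: Int, max: Int, count: Long)

  // Final aggregation over the partials: min of mins, max of maxes, but a SUM of
  // counts, which is why aggregate.Count is replaced by aggregate.Sum in the plan.
  def combine(partials: Seq[Partial]): (Int, Int, Long) =
    (partials.map(_.min).min, partials.map(_.max).max, partials.map(_.count).sum)

  def main(args: Array[String]): Unit = {
    val perRowGroup = Seq(Partial(1, 7, 100L), Partial(3, 9, 250L))
    println(combine(perRowGroup))   // (1,9,350)
  }
}

Running it prints (1,9,350), which is what the final Spark aggregate should produce over two row groups.
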
- var i = 0 - plan.transformExpressions { - case agg: AggregateExpression => - i += 1 - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case _: aggregate.Max => aggregate.Max(output(i - 1)) - case _: aggregate.Min => aggregate.Min(output(i - 1)) - case _: aggregate.Sum => aggregate.Sum(output(i - 1)) - case _: aggregate.Count => aggregate.Sum(output(i - 1)) - case _ => agg.aggregateFunction - } - agg.copy(aggregateFunction = aggFunction, filter = None) - } + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t; + // The original logical plan is + // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9] parquet ... + // + // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // we have the following + // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] + // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + var i = 0 + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case _: aggregate.Max => aggregate.Max(output(i - 1)) + case _: aggregate.Min => aggregate.Min(output(i - 1)) + case _: aggregate.Sum => aggregate.Sum(output(i - 1)) + case _: aggregate.Count => aggregate.Sum(output(i - 1)) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction, filter = None) } } + } else { aggNode } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index dce1bc869a880..95642761394bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ - import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} +import org.apache.spark.sql.connector.read.SupportsPushDownAggregates.AggregatePushDownResult import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} @@ -75,10 +75,10 @@ case class ParquetScanBuilder( private var pushedAggregations = Aggregation.empty - override def pushAggregation(aggregation: Aggregation): Boolean = { + override def pushAggregation(aggregation: Aggregation): AggregatePushDownResult = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || aggregation.groupByColumns.nonEmpty || pushedParquetFilters.length > 0) { - return false + return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) } aggregation.aggregateExpressions.foreach { @@ -88,7 +88,8 @@ case class ParquetScanBuilder( .dataType match { // not push down nested column and 
Timestamp (INT96 sort order is undefined, parquet // doesn't return statistics for INT96) - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => + return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) case _ => } case _ => @@ -100,7 +101,8 @@ case class ParquetScanBuilder( .dataType match { // not push down nested column and Timestamp (INT96 sort order is undefined, parquet // doesn't return statistics for INT96) - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => + return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) case _ => } case _ => @@ -108,15 +110,14 @@ case class ParquetScanBuilder( } // not push down distinct count case Count(_, _, false) => - case _ => return false + case _ => + return return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) } this.pushedAggregations = aggregation - true + new AggregatePushDownResult (1, getPushDownAggSchema) } - override def supportsGlobalAggregatePushDownOnly(): Boolean = true - - override def getPushDownAggSchema: StructType = { + private def getPushDownAggSchema: StructType = { var schema = new StructType() pushedAggregations.aggregateExpressions.foreach { case Max(col, _) => col match { From 564e6de0718856f93588d0ed7350349de4269236 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 23 Jun 2021 00:19:49 -0500 Subject: [PATCH 25/30] fix lint-scala --- .../execution/datasources/v2/parquet/ParquetScanBuilder.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 95642761394bf..6b9f49a26ffe3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ + import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} -import org.apache.spark.sql.connector.read.SupportsPushDownAggregates.AggregatePushDownResult import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.SupportsPushDownAggregates.AggregatePushDownResult import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder From 1648070d6fffbb5100de77f8246ca067155c631c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 2 Jul 2021 14:05:23 -0700 Subject: [PATCH 26/30] remove completely push down status --- .../read/SupportsPushDownAggregates.java | 5 -- .../v2/parquet/ParquetScanBuilder.scala | 59 +++++++------------ 2 files changed, 21 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 40ab1b0347463..66a9c71d4e072 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -46,7 +46,6 @@ class AggregatePushDownResult { // 0: aggregates not pushed down // 1: aggregates partially pushed down, need to final aggregate in Spark - // 2: aggregates completely pushed down, doesn't need to final aggregate in Spark int pushedDownResult = 0; StructType pushedDownAggSchema; @@ -58,9 +57,5 @@ public AggregatePushDownResult(int pushedDownResult, StructType pushedDownAggSch public int getPushedDownResult() { return pushedDownResult; } - - public StructType getPushedDownAggSchema() { - return pushedDownAggSchema; - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 6b9f49a26ffe3..741e57232de38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -76,82 +76,65 @@ case class ParquetScanBuilder( private var pushedAggregations = Aggregation.empty + private var pushedAggregateSchema = new StructType() + override def pushAggregation(aggregation: Aggregation): AggregatePushDownResult = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || aggregation.groupByColumns.nonEmpty || pushedParquetFilters.length > 0) { - return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) + return new AggregatePushDownResult (0, new StructType()) } aggregation.aggregateExpressions.foreach { case Max(col, _) => col match { case ref: FieldReference => - dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) - .dataType match { + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + field.dataType match { // not push down nested column and Timestamp (INT96 sort order is undefined, parquet // doesn't return statistics for INT96) case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => - return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) + return new AggregatePushDownResult (0, new StructType()) case _ => + pushedAggregateSchema = + pushedAggregateSchema.add(field.copy("max(" + field.name + ")")) } case _ => throw new SparkException("Expression $col is not currently supported.") } case Min(col, _) => col match { case ref: FieldReference => - dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) - .dataType match { + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + field.dataType match { // not push down nested column and Timestamp (INT96 sort order is undefined, parquet // doesn't return statistics for INT96) case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => - return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) + return new AggregatePushDownResult (0, new StructType()) case _ => + pushedAggregateSchema = + pushedAggregateSchema.add(field.copy("min(" + field.name + ")")) } case _ => throw new SparkException("Expression $col is not currently supported.") } // not push down 
distinct count - case Count(_, _, false) => - case _ => - return return new AggregatePushDownResult (0, StructType(Array.empty[StructField])) - } - this.pushedAggregations = aggregation - new AggregatePushDownResult (1, getPushDownAggSchema) - } - - private def getPushDownAggSchema: StructType = { - var schema = new StructType() - pushedAggregations.aggregateExpressions.foreach { - case Max(col, _) => col match { - case ref: FieldReference => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) - schema = schema.add(field.copy("max(" + field.name + ")")) - case _ => - throw new SparkException("Expression $col is not currently supported.") - } - case Min(col, _) => col match { - case ref: FieldReference => - val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) - schema = schema.add(field.copy("min(" + field.name + ")")) - case _ => - throw new SparkException("Expression $col is not currently supported.") - } - case Count(col, _, _) => col match { + case Count(col, _, false) => col match { case _: FieldReference => - schema = schema.add(StructField("count(" + col + ")", LongType)) + pushedAggregateSchema = + pushedAggregateSchema.add(StructField("count(" + col + ")", LongType)) case LiteralValue(1, _) => - schema = schema.add(StructField("count(*)", LongType)) + pushedAggregateSchema = pushedAggregateSchema.add(StructField("count(*)", LongType)) case _ => throw new SparkException("Expression $col is not currently supported.") } case _ => - throw new SparkException("Pushed down aggregate is not supported.") + return return new AggregatePushDownResult (0, new StructType()) } - schema + this.pushedAggregations = aggregation + new AggregatePushDownResult (1, pushedAggregateSchema) } override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, pushedAggregations, getPushDownAggSchema, + readPartitionSchema(), pushedParquetFilters, pushedAggregations, pushedAggregateSchema, options) } } From ccdc543da7c1befa42a3a21c2704a18b239e31d7 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Jul 2021 17:04:51 -0700 Subject: [PATCH 27/30] address comments --- .../spark/sql/connector/read/ScanBuilder.java | 7 +-- .../read/SupportsPushDownAggregates.java | 29 ++------- .../sql/catalyst/planning/patterns.scala | 16 ++--- .../datasources/DataSourceStrategy.scala | 38 ++++++------ .../datasources/v2/PushDownUtils.scala | 7 +-- .../v2/V2ScanRelationPushDown.scala | 62 ++++++------------- .../v2/parquet/ParquetScanBuilder.scala | 31 ++++------ 7 files changed, 71 insertions(+), 119 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index 565af7cae3ccd..956060985069b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -22,11 +22,8 @@ /** * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ * interfaces to do operator pushdown, and keep the operator pushdown result in the returned - * {@link Scan}. - * - * The operators in the Scan can be pushed down to the data source layer. 
- * If applicable (the operator is present and the source supports that operator), Spark pushes - * down filters to the source first, then push down aggregation and apply column pruning. + * {@link Scan}. When pushing down operators, Spark pushes down filters first, then push down + * aggregates or apply column pruning. * * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 66a9c71d4e072..0499b5db7b67d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -19,15 +19,13 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Aggregation; -import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link ScanBuilder}. Data source can implement this interface to * push down aggregates. Depends on the data source implementation, the aggregates may not - * be able to push down, partially push down and have final aggregate at Spark, or completely - * push down. + * be able to push down, or partially push down and have final aggregate at Spark. * - * When pushing down operators, Spark pushes down filter to the data source first, then push down + * When pushing down operators, Spark pushes down filters to the data source first, then push down * aggregates or apply column pruning. Depends on data source implementation, aggregates may or * may not be able to be pushed down with filters. If pushed filters still need to be evaluated * after scanning, aggregates can't be pushed down. @@ -38,24 +36,9 @@ public interface SupportsPushDownAggregates extends ScanBuilder { /** - * Pushes down Aggregation to datasource. + * Pushes down Aggregation to datasource. The order of the datasource scan output is: + * grouping columns, aggregate columns (in the same order as the aggregate functions in + * the given Aggregation. 
*/ - AggregatePushDownResult pushAggregation(Aggregation aggregation); - - class AggregatePushDownResult { - - // 0: aggregates not pushed down - // 1: aggregates partially pushed down, need to final aggregate in Spark - int pushedDownResult = 0; - StructType pushedDownAggSchema; - - public AggregatePushDownResult(int pushedDownResult, StructType pushedDownAggSchema) { - this.pushedDownResult = pushedDownResult; - this.pushedDownAggSchema = pushedDownAggSchema; - } - - public int getPushedDownResult() { - return pushedDownResult; - } - } + boolean pushAggregation(Aggregation aggregation); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index a53c0121d73bd..c22a874779fca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -48,14 +48,6 @@ trait OperationHelper { .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) } } - - protected def hasCommonNonDeterministic( - expr: Seq[Expression], - aliases: AttributeMap[Expression]): Boolean = { - expr.exists(_.collect { - case a: AttributeReference if aliases.contains(a) => aliases(a) - }.exists(!_.deterministic)) - } } /** @@ -124,6 +116,14 @@ object ScanOperation extends OperationHelper with PredicateHelper { } } + private def hasCommonNonDeterministic( + expr: Seq[Expression], + aliases: AttributeMap[Expression]): Boolean = { + expr.exists(_.collect { + case a: AttributeReference if aliases.contains(a) => aliases(a) + }.exists(!_.deterministic)) + } + private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = { plan match { case Project(fields, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 53980b25b6578..8f04c4f1df734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -675,24 +675,26 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate( - aggregates: AggregateExpression, - pushableColumn: PushableColumnBase): Option[AggregateFunc] = { - aggregates.aggregateFunction match { - case min @ aggregate.Min(pushableColumn(name)) => - Some(Min(FieldReference(Seq(name)), min.dataType)) - case max @ aggregate.Max(pushableColumn(name)) => - Some(Max(FieldReference(Seq(name)), max.dataType)) - case count: aggregate.Count => - count.children.head match { - // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table - case Literal(_, _) => - Some(Count(LiteralValue(1, LongType), LongType, aggregates.isDistinct)) - case pushableColumn(name) => - Some(Count(FieldReference(Seq(name)), LongType, aggregates.isDistinct)) - case _ => None - } - case _ => None + protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { + if (aggregates.filter.isEmpty) { + aggregates.aggregateFunction match { + case min@aggregate.Min(PushableColumnAndNestedColumn(name)) => + Some(Min(FieldReference(Seq(name)), min.dataType)) + case max@aggregate.Max(PushableColumnAndNestedColumn(name)) => + Some(Max(FieldReference(Seq(name)), max.dataType)) + 
case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table + case Literal(_, _) => + Some(Count(LiteralValue(1L, LongType), LongType, aggregates.isDistinct)) + case PushableColumnAndNestedColumn(name) => + Some(Count(FieldReference(Seq(name)), LongType, aggregates.isDistinct)) + case _ => None + } + case _ => None + } + } else { + None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 87d52ae8e557f..2b0e81e029321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -89,12 +89,11 @@ object PushDownUtils extends PredicateHelper { scanBuilder match { case r: SupportsPushDownAggregates => - val translatedAggregates = aggregates.map(DataSourceStrategy - .translateAggregate(_, PushableColumn(false))) + val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate) val translatedGroupBys = groupBy.map(columnAsString) val agg = Aggregation(translatedAggregates.flatten, translatedGroupBys.flatten) - if (r.pushAggregation(agg).getPushedDownResult > 0) { + if (r.pushAggregation(agg)) { agg } else { Aggregation.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index a022bfd2292f5..95ba34de85c78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.{OperationHelper, ScanOperation} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -44,9 +43,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper private def pushDownFilters(plan: LogicalPlan) = plan.transform { // update the scan builder with filter push down and return a new plan with filter pushed - case filter @ Filter(_, sHolder: ScanBuilderHolder) => - val (filters, _, _) = 
collectFilters(filter).get - + case Filter(condition, sHolder: ScanBuilderHolder) => + val filters = splitConjunctivePredicates(condition) val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output) val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = @@ -72,11 +70,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed - case aggNode@Aggregate(groupingExpressions, resultExpressions, child) => + case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) + if project.forall(_.isInstanceOf[AttributeReference]) => sHolder.builder match { - case r: SupportsPushDownAggregates => + case _: SupportsPushDownAggregates => if (filters.length == 0) { // can't push down aggregate if postScanFilters exist val aggregates = getAggregateExpression(resultExpressions, project, sHolder) val pushedAggregates = PushDownUtils @@ -92,9 +91,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper // == Optimized Logical Plan == // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - val output = aggregates.map { - case agg: AggregateExpression => - AttributeReference(toPrettySQL(agg), agg.dataType)() + var index = 0 + val output = resultExpressions.map { + case Alias(_, name) => + index = index + 1 + AttributeReference(name, aggregates(index - 1).dataType)() + case a: AttributeReference => a } // No need to do column pruning because only the aggregate columns are used as @@ -113,7 +115,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper """.stripMargin) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) - val plan = Aggregate(groupingExpressions, resultExpressions, scanRelation) + assert(scanRelation.output.length == + groupingExpressions.length + aggregates.length) + + val plan = Aggregate( + output.take(groupingExpressions.length), resultExpressions, scanRelation) // Change the optimized logical plan to reflect the pushed down aggregate // e.g. TABLE t (c1 INT, c2 INT, c3 INT) @@ -122,10 +128,10 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] // +- RelationV2[c1#9] parquet ... // - // After change the V2ScanRelation output to [min(_1)#21, max(_1)#22] + // After change the V2ScanRelation output to [min(c1)#21, max(c1)#22] // we have the following - // !Aggregate [min(_1#9) AS min(_1)#17, max(_1#9) AS max(_1)#18] - // +- RelationV2[min(_1)#21, max(_1)#22] parquet ... + // !Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[min(c1)#21, max(c1)#22] parquet ... 
// // We want to change it to // == Optimized Logical Plan == @@ -146,7 +152,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper agg.copy(aggregateFunction = aggFunction, filter = None) } } - } else { aggNode } @@ -217,33 +222,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper DataSourceStrategy.normalizeExprs(aggregates, sHolder.relation.output) .asInstanceOf[Seq[AggregateExpression]] } - - private def collectFilters(plan: LogicalPlan): - Option[(Seq[Expression], LogicalPlan, AttributeMap[Expression])] = { - plan match { - case Filter(condition, child) => - collectFilters(child) match { - case Some((filters, other, aliases)) => - // Follow CombineFilters and only keep going if 1) the collected Filters - // and this filter are all deterministic or 2) if this filter is the first - // collected filter and doesn't have common non-deterministic expressions - // with lower Project. - val substitutedCondition = substitute(aliases)(condition) - val canCombineFilters = (filters.nonEmpty && filters.forall(_.deterministic) && - substitutedCondition.deterministic) || filters.isEmpty - if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) { - Some((filters ++ splitConjunctivePredicates(substitutedCondition), - other, aliases)) - } else { - None - } - case None => None - } - - case other => - Some((Nil, other, AttributeMap(Seq()))) - } - } } case class ScanBuilderHolder( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 741e57232de38..9a4b6a357a347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} -import org.apache.spark.sql.connector.read.SupportsPushDownAggregates.AggregatePushDownResult import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -78,10 +77,10 @@ case class ParquetScanBuilder( private var pushedAggregateSchema = new StructType() - override def pushAggregation(aggregation: Aggregation): AggregatePushDownResult = { + override def pushAggregation(aggregation: Aggregation): Boolean = { if (!sparkSession.sessionState.conf.parquetAggregatePushDown || aggregation.groupByColumns.nonEmpty || pushedParquetFilters.length > 0) { - return new AggregatePushDownResult (0, new StructType()) + return false } aggregation.aggregateExpressions.foreach { @@ -91,11 +90,9 @@ case class ParquetScanBuilder( field.dataType match { // not push down nested column and Timestamp (INT96 sort order is undefined, parquet // doesn't return statistics for INT96) - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => - return new AggregatePushDownResult (0, new StructType()) - case _ => - pushedAggregateSchema = - pushedAggregateSchema.add(field.copy("max(" + field.name + ")")) 
+ case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case _ => pushedAggregateSchema = + pushedAggregateSchema.add(field.copy("max(" + field.name + ")")) } case _ => throw new SparkException("Expression $col is not currently supported.") @@ -106,30 +103,26 @@ case class ParquetScanBuilder( field.dataType match { // not push down nested column and Timestamp (INT96 sort order is undefined, parquet // doesn't return statistics for INT96) - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => - return new AggregatePushDownResult (0, new StructType()) - case _ => - pushedAggregateSchema = - pushedAggregateSchema.add(field.copy("min(" + field.name + ")")) + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case _ => pushedAggregateSchema = + pushedAggregateSchema.add(field.copy("min(" + field.name + ")")) } case _ => throw new SparkException("Expression $col is not currently supported.") } // not push down distinct count case Count(col, _, false) => col match { - case _: FieldReference => - pushedAggregateSchema = - pushedAggregateSchema.add(StructField("count(" + col + ")", LongType)) + case _: FieldReference => pushedAggregateSchema = + pushedAggregateSchema.add(StructField("count(" + col + ")", LongType)) case LiteralValue(1, _) => pushedAggregateSchema = pushedAggregateSchema.add(StructField("count(*)", LongType)) case _ => throw new SparkException("Expression $col is not currently supported.") } - case _ => - return return new AggregatePushDownResult (0, new StructType()) + case _ => return false } this.pushedAggregations = aggregation - new AggregatePushDownResult (1, pushedAggregateSchema) + true } override def build(): Scan = { From 7540b59ec2ab65d8d3ddffc063e0d4f3a3670008 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 13 Jul 2021 13:45:01 -0700 Subject: [PATCH 28/30] address comments --- .../read/SupportsPushDownAggregates.java | 15 ++- .../connector/expressions/aggregates.scala | 5 - .../datasources/v2/PushDownUtils.scala | 19 ++-- .../v2/V2ScanRelationPushDown.scala | 92 +++++++++---------- .../ParquetPartitionReaderFactory.scala | 18 ++-- .../datasources/v2/parquet/ParquetScan.scala | 34 ++++++- .../v2/parquet/ParquetScanBuilder.scala | 4 +- .../org/apache/spark/sql/FileScanSuite.scala | 4 +- .../parquet/ParquetQuerySuite.scala | 51 +++++++++- 9 files changed, 155 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 0499b5db7b67d..f1524b32fef3f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -23,7 +23,14 @@ /** * A mix-in interface for {@link ScanBuilder}. Data source can implement this interface to * push down aggregates. Depends on the data source implementation, the aggregates may not - * be able to push down, or partially push down and have final aggregate at Spark. + * be able to push down, or partially push down and have a final aggregate at Spark. + * For example, "SELECT min(_1) FROM t GROUP BY _2" can be pushed down to data source, + * the partially aggregated result min(_1) grouped by _2 will be returned to Spark, and + * then have a final aggregation. 
+ * {{{ + * Aggregate [_2#10], [min(_2#10) AS min(_1)#16] + * +- RelationV2[_2#10, min(_1)#18] + * }}} * * When pushing down operators, Spark pushes down filters to the data source first, then push down * aggregates or apply column pruning. Depends on data source implementation, aggregates may or @@ -36,9 +43,9 @@ public interface SupportsPushDownAggregates extends ScanBuilder { /** - * Pushes down Aggregation to datasource. The order of the datasource scan output is: - * grouping columns, aggregate columns (in the same order as the aggregate functions in - * the given Aggregation. + * Pushes down Aggregation to datasource. The order of the datasource scan output columns should + * be: grouping columns, aggregate columns (in the same order as the aggregate functions in + * the given Aggregation). */ boolean pushAggregation(Aggregation aggregation); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala index a4f1440df58f9..ac996d835f04b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala @@ -33,8 +33,3 @@ case class Sum(column: Expression, dataType: DataType, isDistinct: Boolean) extends AggregateFunc case class Count(column: Expression, dataType: DataType, isDistinct: Boolean) extends AggregateFunc - -object Aggregation { - // Returns an empty Aggregate - def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[Expression]) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 2b0e81e029321..9b552150fc0ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -80,7 +80,7 @@ object PushDownUtils extends PredicateHelper { def pushAggregates( scanBuilder: ScanBuilder, aggregates: Seq[AggregateExpression], - groupBy: Seq[Expression]): Aggregation = { + groupBy: Seq[Expression]): Option[Aggregation] = { def columnAsString(e: Expression): Option[FieldReference] = e match { case AttributeReference(name, _, _, _) => Some(FieldReference(Seq(name))) @@ -89,16 +89,21 @@ object PushDownUtils extends PredicateHelper { scanBuilder match { case r: SupportsPushDownAggregates => - val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate) - val translatedGroupBys = groupBy.map(columnAsString) + val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten + val translatedGroupBys = groupBy.map(columnAsString).flatten - val agg = Aggregation(translatedAggregates.flatten, translatedGroupBys.flatten) + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + val agg = Aggregation(translatedAggregates, translatedGroupBys) if (r.pushAggregation(agg)) { - agg + Some(agg) } else { - Aggregation.empty + None } - case _ => Aggregation.empty + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 95ba34de85c78..2ffaf729f8df2 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.{OperationHelper, ScanOperation} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources @@ -77,28 +76,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper sHolder.builder match { case _: SupportsPushDownAggregates => if (filters.length == 0) { // can't push down aggregate if postScanFilters exist - val aggregates = getAggregateExpression(resultExpressions, project, sHolder) + val aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + } val pushedAggregates = PushDownUtils .pushAggregates(sHolder.builder, aggregates, groupingExpressions) - if (pushedAggregates.aggregateExpressions.isEmpty) { + if (pushedAggregates.isEmpty) { aggNode // return original plan node } else { - // use the aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; - // Use min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... - var index = 0 - val output = resultExpressions.map { - case Alias(_, name) => - index = index + 1 - AttributeReference(name, aggregates(index - 1).dataType)() - case a: AttributeReference => a - } - // No need to do column pruning because only the aggregate columns are used as // DataSourceV2ScanRelation output columns. All the other columns are not // included in the output. Since PushDownUtils.pruneColumns is not called, @@ -106,11 +93,30 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper // not used anyways. The schema for aggregate columns will be built in Scan. val scan = sHolder.builder.build() + // scalastyle:off + // use the group by columns and aggregate columns as the output columns + // e.g. 
TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] + // scalastyle:on + val newOutput = scan.readSchema().toAttributes + val groupAttrs = groupingExpressions.zip(newOutput).map { + case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) + case other => other.asInstanceOf[AttributeReference] + } + val output = groupAttrs ++ newOutput.drop(groupAttrs.length) + logInfo( s""" |Pushing operators to ${sHolder.relation.name} |Pushed Aggregate Functions: - | ${pushedAggregates.aggregateExpressions.mkString(", ")} + | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} + |Pushed Group by: + | ${pushedAggregates.get.groupByColumns.mkString(", ")} |Output: ${output.mkString(", ")} """.stripMargin) @@ -121,32 +127,35 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper val plan = Aggregate( output.take(groupingExpressions.length), resultExpressions, scanRelation) + // scalastyle:off // Change the optimized logical plan to reflect the pushed down aggregate // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t; + // SELECT min(c1), max(c1) FROM t GROUP BY c2; // The original logical plan is - // Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9] parquet ... + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... // - // After change the V2ScanRelation output to [min(c1)#21, max(c1)#22] + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] // we have the following - // !Aggregate [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet ... + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... // // We want to change it to // == Optimized Logical Plan == - // Aggregate [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[min(c1)#21, max(c1)#22] parquet file ... + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
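// --- Editorial sketch (not part of the patch) ---------------------------------------
// The rewrite in the hunk below maps a pushed-down Count to a Sum while Min/Max stay
// Min/Max. The reason is that each Parquet file (or row group) now returns a *partial*
// result that Spark must re-aggregate: a min of mins, a max of maxes, but a sum of
// counts. The per-file numbers here are made up purely for illustration.
object ReAggregationSketch {
  // (min(c1), max(c1), count(c1)) as a data source could derive them from footer stats.
  private val perFilePartials = Seq((1, 9, 4L), (0, 7, 3L), (2, 5, 5L))

  def main(args: Array[String]): Unit = {
    val globalMin   = perFilePartials.map(_._1).min  // Min(Min(...))
    val globalMax   = perFilePartials.map(_._2).max  // Max(Max(...))
    val globalCount = perFilePartials.map(_._3).sum  // Count becomes Sum(Count(...))
    assert((globalMin, globalMax, globalCount) == (0, 9, 12L))
  }
}
// -------------------------------------------------------------------------------------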
+ // scalastyle:on var i = 0 + val aggOutput = output.drop(groupAttrs.length) plan.transformExpressions { case agg: AggregateExpression => i += 1 val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match { - case _: aggregate.Max => aggregate.Max(output(i - 1)) - case _: aggregate.Min => aggregate.Min(output(i - 1)) - case _: aggregate.Sum => aggregate.Sum(output(i - 1)) - case _: aggregate.Count => aggregate.Sum(output(i - 1)) + case _: aggregate.Max => aggregate.Max(aggOutput(i - 1)) + case _: aggregate.Min => aggregate.Min(aggOutput(i - 1)) + case _: aggregate.Sum => aggregate.Sum(aggOutput(i - 1)) + case _: aggregate.Count => aggregate.Sum(aggOutput(i - 1)) case _ => agg.aggregateFunction } agg.copy(aggregateFunction = aggFunction, filter = None) @@ -183,7 +192,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper f.pushedFilters() case _ => Array.empty[sources.Filter] } - V1ScanWrapper(v1, translated, pushedFilters, Aggregation.empty) + V1ScanWrapper(v1, translated, pushedFilters) case _ => scan } @@ -208,20 +217,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper } withProjection } - - private def getAggregateExpression( - resultExpressions: Seq[NamedExpression], - project: Seq[NamedExpression], - sHolder: ScanBuilderHolder): Seq[AggregateExpression] = { - val aggregates = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => - replaceAlias(agg, getAliasMap(project)).asInstanceOf[AggregateExpression] - } - } - DataSourceStrategy.normalizeExprs(aggregates, sHolder.relation.output) - .asInstanceOf[Seq[AggregateExpression]] - } } case class ScanBuilderHolder( @@ -234,7 +229,6 @@ case class ScanBuilderHolder( case class V1ScanWrapper( v1Scan: V1Scan, translatedFilters: Seq[sources.Filter], - handledFilters: Seq[sources.Filter], - pushedAggregates: Aggregation) extends Scan { + handledFilters: Seq[sources.Filter]) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 96b98e9c12789..5be31c8367b80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -67,10 +67,10 @@ case class ParquetPartitionReaderFactory( partitionSchema: StructType, aggSchema: StructType, filters: Array[Filter], - aggregation: Aggregation, + aggregation: Option[Aggregation], parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val newReadDataSchema = if (aggregation.aggregateExpressions.isEmpty) { + private val newReadDataSchema = if (aggregation.isEmpty) { readDataSchema } else { aggSchema @@ -96,7 +96,7 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) - if (aggregation.aggregateExpressions.isEmpty) { + if (aggregation.isEmpty) { ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) } else { ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) @@ -122,7 +122,7 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val fileReader = 
if (aggregation.aggregateExpressions.isEmpty) { + val fileReader = if (aggregation.isEmpty) { val reader = if (enableVectorizedReader) { createVectorizedReader(file) @@ -146,8 +146,8 @@ case class ParquetPartitionReaderFactory( override def get(): InternalRow = { count += 1 val footer = getFooter(file) - ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, aggregation, aggSchema, - datetimeRebaseModeInRead) + ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, aggregation.get, + aggSchema, datetimeRebaseModeInRead) } override def close(): Unit = return @@ -159,7 +159,7 @@ case class ParquetPartitionReaderFactory( } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val fileReader = if (aggregation.aggregateExpressions.isEmpty) { + val fileReader = if (aggregation.isEmpty) { val vectorizedReader = createVectorizedReader(file) vectorizedReader.enableReturningBatches() @@ -180,8 +180,8 @@ case class ParquetPartitionReaderFactory( override def get(): ColumnarBatch = { count += 1 val footer = getFooter(file) - ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, aggregation, aggSchema, - enableOffHeapColumnVector, datetimeRebaseModeInRead) + ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, aggregation.get, + aggSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead) } override def close(): Unit = return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 9d4149cd8da9b..8e8e8cac6e0a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -43,7 +43,7 @@ case class ParquetScan( readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], - pushedAggregations: Aggregation = Aggregation.empty, + pushedAggregations: Option[Aggregation], pushedDownAggSchema: StructType, options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty, @@ -96,28 +96,52 @@ case class ParquetScan( override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => + val pushedDownAggEqual = if (pushedAggregations.nonEmpty) { + equivalentAggregations(pushedAggregations.get, p.pushedAggregations.get) + } else { + true + } super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) && - equivalentAggregations(pushedAggregations, p.pushedAggregations) + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val pushedAggregationsStr = if (pushedAggregations.nonEmpty) { + seqToString(pushedAggregations.get.aggregateExpressions) + } else { + "[]" + } + + lazy private val pushedGroupByStr = if (pushedAggregations.nonEmpty) { + seqToString(pushedAggregations.get.groupByColumns) + } else { + "[]" + } + override def description(): String = { super.description() + ", PushedFilters: " + seqToString(pushedFilters) + - ", PushedAggregation: " + seqToString(pushedAggregations.aggregateExpressions) + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> 
seqToString(pushedFilters)) ++ - Map("PushedAggregation" -> seqToString(pushedAggregations.aggregateExpressions)) + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } override def withFilters( partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + override def readSchema(): StructType = if (pushedAggregations.nonEmpty) { + pushedDownAggSchema + } else { + StructType(readDataSchema.fields ++ readPartitionSchema.fields) + } + // Returns whether the two given [[Aggregation]]s are equivalent. private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { a.aggregateExpressions.sortBy(_.hashCode()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 9a4b6a357a347..f2cfff3b2a98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -73,7 +73,7 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. override def pushedFilters(): Array[Filter] = pushedParquetFilters - private var pushedAggregations = Aggregation.empty + private var pushedAggregations = Option.empty[Aggregation] private var pushedAggregateSchema = new StructType() @@ -121,7 +121,7 @@ case class ParquetScanBuilder( } case _ => return false } - this.pushedAggregations = aggregation + this.pushedAggregations = Some(aggregation) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 2ac1a77fefb45..36ba853bf19e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -355,8 +355,8 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, Aggregation.empty, - null, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, + Option.empty[Aggregation], null, o, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index ca38443b6bf73..955ba3875e850 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupport import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.functions.min import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession 
import org.apache.spark.sql.types._ @@ -904,7 +905,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } - test("test aggregate push down - nested data ") { + test("test aggregate push down - nested data shouldn't be pushed down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withSQLConf( SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { @@ -921,6 +922,42 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } } + test("test aggregate push down alias") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + + val selectAgg1 = sql("SELECT min(_1) + max(_1) as res FROM t") + + selectAgg1.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Min(_1,IntegerType), Max(_1,IntegerType)]" + checkKeywordsExistsInExplain(selectAgg1, expected_plan_fragment) + } + + val selectAgg2 = sql("SELECT min(_1) as minValue, max(_1) as maxValue FROM t") + selectAgg2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [Min(_1,IntegerType), Max(_1,IntegerType)]" + checkKeywordsExistsInExplain(selectAgg2, expected_plan_fragment) + } + + val df = spark.table("t") + val query = df.select($"_1".as("col1")).agg(min($"col1")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" // aggregate alias not pushed down + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + } + } + } + test("test aggregate push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) @@ -928,6 +965,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS withSQLConf( SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is a filter val selectAgg1 = sql("SELECT min(_3) FROM t WHERE _1 > 0") selectAgg1.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -940,9 +978,14 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS val selectAgg2 = sql("SELECT min(_3 + _1), max(_3 + _1) FROM t") checkAnswer(selectAgg2, Seq(Row(0, 19))) - // sum is not pushed down - val selectAgg3 = sql("SELECT sum(_3) FROM t") - checkAnswer(selectAgg3, Seq(Row(40))) + // aggregate not pushed down if one of them can't be pushed down + val selectAgg3 = sql("SELECT min(_1), sum(_3) FROM t") + selectAgg3.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg3, expected_plan_fragment) + } val selectAgg4 = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + " count(*), count(_1), count(_2), count(_3) FROM t") From 2c889c6c79d1b8a46bb2f3d5a578a720b195d509 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 13 Jul 2021 14:30:10 -0700 Subject: [PATCH 29/30] fix build failure: --- .../sql/execution/datasources/v2/DataSourceV2Strategy.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1585bc040fda1..811f41832d159 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -86,8 +86,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed, - aggregation), output)) => + relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( From 5c2b630a61d8ea9263116c37e154884a5cabd2ca Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 14 Jul 2021 11:51:07 -0700 Subject: [PATCH 30/30] address comments --- .../datasources/v2/V2ScanRelationPushDown.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 2ffaf729f8df2..48915e91d0dcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.{OperationHelper, ScanOperation} +import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} @@ -27,8 +27,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType -object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper - with OperationHelper with PredicateHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { @@ -106,7 +105,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper val newOutput = scan.readSchema().toAttributes val groupAttrs = groupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) - case other => other.asInstanceOf[AttributeReference] + case (_, b) => b } val output = groupAttrs ++ newOutput.drop(groupAttrs.length) @@ -158,7 +157,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper case _: aggregate.Count => aggregate.Sum(aggOutput(i - 1)) case _ => agg.aggregateFunction } - agg.copy(aggregateFunction = aggFunction, filter = None) + agg.copy(aggregateFunction = aggFunction) } } } else {
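
Editorial note, not part of the patch series: the sketch below is a minimal, hypothetical end-to-end usage of the feature exercised by the tests above. The local master URL, the temporary path, and clearing spark.sql.sources.useV1SourceList so the DSv2 ParquetScan is planned are all illustrative assumptions, not requirements stated by the patches.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object ParquetAggPushDownUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")                        // assumption: local demo session
      .appName("parquet-agg-pushdown-sketch")
      .getOrCreate()

    // Enable the flag the tests above set explicitly.
    spark.conf.set(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key, "true")
    // Assumption: route parquet through the DSv2 ParquetScan so push-down can apply.
    spark.conf.set("spark.sql.sources.useV1SourceList", "")

    val path = "/tmp/parquet_agg_pushdown_demo"  // illustrative path only
    spark.range(0, 100).selectExpr("id AS _1", "id % 3 AS _2")
      .write.mode("overwrite").parquet(path)

    val agg = spark.read.parquet(path).selectExpr("min(_1)", "max(_1)", "count(_1)")
    // When the aggregate is pushed, the scan's description/metadata should show a
    // non-empty "PushedAggregation: [...]" (and "PushedGroupBy: [...]" when grouping
    // columns are pushed as well).
    agg.explain()
    agg.show()

    spark.stop()
  }
}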