diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index cb3eea7680058..956060985069b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -22,7 +22,8 @@ /** * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ * interfaces to do operator pushdown, and keep the operator pushdown result in the returned - * {@link Scan}. + * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down + * aggregates or applies column pruning. * * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java new file mode 100644 index 0000000000000..f1524b32fef3f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Aggregation; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down aggregates. Depending on the data source implementation, the aggregates may not be + * pushed down at all, or may be partially pushed down, with a final aggregate done by Spark. + * For example, "SELECT min(_1) FROM t GROUP BY _2" can be pushed down to the data source: + * the partially aggregated result min(_1), grouped by _2, is returned to Spark, which + * then performs the final aggregation. + * {{{ + * Aggregate [_2#10], [min(min(_1)#18) AS min(_1)#16] + * +- RelationV2[_2#10, min(_1)#18] + * }}} + * + * When pushing down operators, Spark pushes down filters to the data source first, then pushes + * down aggregates or applies column pruning. Depending on the data source implementation, + * aggregates may or may not be pushed down together with filters. If pushed filters still need + * to be evaluated after scanning, aggregates can't be pushed down. + * + * @since 3.2.0 + */ +@Evolving +public interface SupportsPushDownAggregates extends ScanBuilder { + + /** + * Pushes down the Aggregation to the data source. The order of the data source scan output + * columns should be: grouping columns, then aggregate columns (in the same order as the + * aggregate functions in the given Aggregation).
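+ * For example, for "SELECT min(_1), max(_1) FROM t GROUP BY _2" the data source scan is expected to output the columns (_2, min(_1), max(_1)). + * + * @return true if the aggregation is handled and pushed to the data source, false otherwise.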
+ */ + boolean pushAggregation(Aggregation aggregation); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala new file mode 100644 index 0000000000000..ac996d835f04b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions + +import org.apache.spark.sql.types.DataType + +// Aggregate Functions in SQL statement. +// e.g. SELECT COUNT(EmployeeID), Max(salary), deptID FROM dept GROUP BY deptID +// aggregateExpressions are (COUNT(EmployeeID), Max(salary)), groupByColumns are (deptID) +case class Aggregation(aggregateExpressions: Seq[AggregateFunc], + groupByColumns: Seq[Expression]) + +abstract class AggregateFunc + +case class Min(column: Expression, dataType: DataType) extends AggregateFunc +case class Max(column: Expression, dataType: DataType) extends AggregateFunc +case class Sum(column: Expression, dataType: DataType, isDistinct: Boolean) + extends AggregateFunc +case class Count(column: Expression, dataType: DataType, isDistinct: Boolean) + extends AggregateFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04e740039f005..d979cd1ef966c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -772,6 +772,12 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") .createWithDefault(10) + val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") + .doc("Enables Parquet aggregate push-down optimization when set to true.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("If true, data will be written in a way of Spark 1.4 and earlier. 
For example, decimal " + "values will be written in Apache Parquet's fixed-length byte array format, which other " + @@ -3423,6 +3429,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownInFilterThreshold: Int = getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index dde5dc2be0556..9c7a05c36a8f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -37,7 +37,8 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil + SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: + PruneFileSourcePartitions :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 58ac924a1d36c..8f04c4f1df734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -33,12 +33,14 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, FieldReference, LiteralValue, Max, Min} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -673,6 +675,29 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } + protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { + if (aggregates.filter.isEmpty) { + aggregates.aggregateFunction match { + case min@aggregate.Min(PushableColumnAndNestedColumn(name)) => + Some(Min(FieldReference(Seq(name)), min.dataType)) + case max@aggregate.Max(PushableColumnAndNestedColumn(name)) => + Some(Max(FieldReference(Seq(name)), max.dataType)) + case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table + case 
Literal(_, _) => + Some(Count(LiteralValue(1L, LongType), LongType, aggregates.isDistinct)) + case PushableColumnAndNestedColumn(name) => + Some(Count(FieldReference(Seq(name)), LongType, aggregates.isDistinct)) + case _ => None + } + case _ => None + } + } else { + None + } + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b91d75c55c513..426e53a9eee93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,11 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal, BigInteger} +import java.util + +import scala.collection.mutable.ArrayBuilder +import scala.language.existentials + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.PrimitiveType +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.unsafe.types.UTF8String object ParquetUtils { def inferSchema( @@ -127,4 +144,254 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + + /** + * When the partial Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to + * createRowBaseReader to read data from parquet and aggregate at spark layer. Instead we want + * to calculate the partial Aggregates (Max/Min/Count) result using the statistics information + * from parquet footer file, and then construct an InternalRow from these Aggregate results. 
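+ * Only a single row is produced per file split, and the Aggregate node kept above the scan merges these per-split partial results into the final values.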
+ * + * @return Aggregate results in the format of InternalRow + */ + private[sql] def createInternalRowFromAggResult( + footer: ParquetMetadata, + dataSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + datetimeRebaseModeInRead: String): InternalRow = { + val (parquetTypes, values) = + ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) + val mutableRow = new SpecificInternalRow(aggSchema.fields.map(x => x.dataType)) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + parquetTypes.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + aggSchema.fields(i).dataType match { + case ByteType => + mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) + case ShortType => + mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) + case IntegerType => + mutableRow.setInt(i, values(i).asInstanceOf[Integer]) + case DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for INT32") + } + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + aggSchema.fields(i).dataType match { + case LongType => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for INT64") + } + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + mutableRow.setFloat(i, values(i).asInstanceOf[Float]) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + mutableRow.setDouble(i, values(i).asInstanceOf[Double]) + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + aggSchema.fields(i).dataType match { + case StringType => + mutableRow.update(i, UTF8String.fromBytes(bytes)) + case BinaryType => + mutableRow.update(i, bytes) + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for Binary") + } + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + aggSchema.fields(i).dataType match { + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for FIXED_LEN_BYTE_ARRAY") + } + case _ => + throw new SparkException("Unexpected parquet type name") + } + mutableRow + } + + /** + * When the Aggregates (Max/Min/Count) are pushed down to parquet, in the case of + * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader + * to read data from parquet and aggregate at spark layer. 
Instead we want + * to calculate the Aggregates (Max/Min/Count) result using the statistics information + * from parquet footer file, and then construct a ColumnarBatch from these Aggregate results. + * + * @return Aggregate results in the format of ColumnarBatch + */ + private[sql] def createColumnarBatchFromAggResult( + footer: ParquetMetadata, + dataSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + offHeap: Boolean, + datetimeRebaseModeInRead: String): ColumnarBatch = { + val (parquetTypes, values) = + ParquetUtils.getPushedDownAggResult(footer, dataSchema, aggregation) + val capacity = 4 * 1024 + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(capacity, aggSchema) + } else { + OnHeapColumnVector.allocateColumns(capacity, aggSchema) + } + + parquetTypes.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + aggSchema.fields(i).dataType match { + case ByteType => + columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) + case ShortType => + columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) + case IntegerType => + columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) + case DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) + case _ => throw new SparkException("Unexpected type for INT32") + } + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) + case _ => + throw new SparkException("Unexpected parquet type name") + } + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } + + /** + * Calculate the pushed down Aggregates (Max/Min/Count) result using the statistics + * information from parquet footer file. + * + * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any]. + * The first element is the PrimitiveTypeName of the Aggregate column, + * and the second element is the aggregated value. 
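+ * @note If the footer is missing the min/max or null count statistics needed for a pushed down aggregate, an UnsupportedOperationException is thrown.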
+ */ + private[sql] def getPushedDownAggResult( + footer: ParquetMetadata, + dataSchema: StructType, + aggregation: Aggregation) + : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = { + val footerFileMetaData = footer.getFileMetaData + val fields = footerFileMetaData.getSchema.getFields + val blocks = footer.getBlocks() + val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] + val valuesBuilder = ArrayBuilder.make[Any] + + aggregation.aggregateExpressions.indices.foreach { i => + var value: Any = None + var rowCount = 0L + var isCount = false + var index = 0 + blocks.forEach { block => + val blockMetaData = block.getColumns() + aggregation.aggregateExpressions(i) match { + case Max(col, _) => col match { + case ref: FieldReference => + index = dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head) + val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) + if (currentMax != None && + (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { + value = currentMax + } + case _ => + throw new SparkException(s"Expression $col is not currently supported.") + } + case Min(col, _) => col match { + case ref: FieldReference => + index = dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head) + val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) + if (currentMin != None && + (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { + value = currentMin + } + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case Count(col, _, _) => + rowCount += block.getRowCount + isCount = true + col match { + case ref: FieldReference => + index = dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head) + // Count(*) includes the null values, but Count (colName) doesn't. + rowCount -= getNumNulls(blockMetaData, index) + case LiteralValue(1, _) => // Count(*) + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case _ => + } + } + if (isCount) { + valuesBuilder += rowCount + typesBuilder += PrimitiveType.PrimitiveTypeName.INT64 + } else { + valuesBuilder += value + typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + } + } + (typesBuilder.result(), valuesBuilder.result()) + } + + /** + * get the Max or Min value for ith column in the current block + * + * @return the Max or Min value + */ + private def getCurrentBlockMaxOrMin( + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int, + isMax: Boolean): Any = { + val statistics = columnChunkMetaData.get(i).getStatistics() + if (!statistics.hasNonNullValue) { + throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" + + " PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + } else { + if (isMax) statistics.genericGetMax() else statistics.genericGetMin() + } + } + + private def getNumNulls( + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int): Long = { + val statistics = columnChunkMetaData.get(i).getStatistics() + if (!statistics.isNumNullsSet()) { + throw new UnsupportedOperationException("Number of nulls not set for parquet file." 
+ + " Set SQLConf PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + } + statistics.getNumNulls(); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 1f57f17911457..9b552150fc0ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -70,6 +72,41 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down aggregates to the data source reader + * + * @return pushed aggregation. + */ + def pushAggregates( + scanBuilder: ScanBuilder, + aggregates: Seq[AggregateExpression], + groupBy: Seq[Expression]): Option[Aggregation] = { + + def columnAsString(e: Expression): Option[FieldReference] = e match { + case AttributeReference(name, _, _, _) => Some(FieldReference(Seq(name))) + case _ => None + } + + scanBuilder match { + case r: SupportsPushDownAggregates => + val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten + val translatedGroupBys = groupBy.map(columnAsString).flatten + + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + val agg = Aggregation(translatedAggregates, translatedGroupBys) + if (r.pushAggregation(agg)) { + Some(agg) + } else { + None + } + case _ => None + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d2180566790ac..48915e91d0dcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,23 +17,34 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.{Scan, V1Scan} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType -object V2ScanRelationPushDown extends Rule[LogicalPlan] { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + def apply(plan: LogicalPlan): LogicalPlan = { + applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan)))) + } + + private def createScanBuilder(plan: LogicalPlan) = plan.transform { + case r: DataSourceV2Relation => + ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options)) + } - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) + private def pushDownFilters(plan: LogicalPlan) = plan.transform { + // update the scan builder with filter push down and return a new plan with filter pushed + case Filter(condition, sHolder: ScanBuilderHolder) => + val filters = splitConjunctivePredicates(condition) + val normalizedFilters = + DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output) val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = normalizedFilters.partition(SubqueryExpression.hasSubquery) @@ -41,37 +52,157 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) + sHolder.builder, normalizedFiltersWithoutSubquery) val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + """.stripMargin) + + val filterCondition = postScanFilters.reduceLeftOption(And) + filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) + } + + def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { + // update the scan builder with agg pushdown and return a new plan with agg pushed + case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => + child match { + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) + if project.forall(_.isInstanceOf[AttributeReference]) => + sHolder.builder match { + case _: SupportsPushDownAggregates => + if (filters.length == 0) { // can't push down aggregate if postScanFilters exist + val aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + } + val pushedAggregates = PushDownUtils + .pushAggregates(sHolder.builder, aggregates, groupingExpressions) + if (pushedAggregates.isEmpty) { + aggNode // return original plan node + } else { + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. Since PushDownUtils.pruneColumns is not called, + // ScanBuilder.requiredSchema is not pruned, but ScanBuilder.requiredSchema is + // not used anyways. The schema for aggregate columns will be built in Scan. + val scan = sHolder.builder.build() + + // scalastyle:off + // use the group by columns and aggregate columns as the output columns + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] + // scalastyle:on + val newOutput = scan.readSchema().toAttributes + val groupAttrs = groupingExpressions.zip(newOutput).map { + case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) + case (_, b) => b + } + val output = groupAttrs ++ newOutput.drop(groupAttrs.length) + + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Aggregate Functions: + | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} + |Pushed Group by: + | ${pushedAggregates.get.groupByColumns.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, scan, output) + assert(scanRelation.output.length == + groupingExpressions.length + aggregates.length) + + val plan = Aggregate( + output.take(groupingExpressions.length), resultExpressions, scanRelation) + + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... 
+ // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // scalastyle:on + var i = 0 + val aggOutput = output.drop(groupAttrs.length) + plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case _: aggregate.Max => aggregate.Max(aggOutput(i - 1)) + case _: aggregate.Min => aggregate.Min(aggOutput(i - 1)) + case _: aggregate.Sum => aggregate.Sum(aggOutput(i - 1)) + case _: aggregate.Count => aggregate.Sum(aggOutput(i - 1)) + case _ => agg.aggregateFunction + } + agg.copy(aggregateFunction = aggFunction) + } + } + } else { + aggNode + } + case _ => aggNode + } + case _ => aggNode + } + } + + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { + case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => + // column pruning val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) + .normalizeExprs(project, sHolder.output) .asInstanceOf[Seq[NamedExpression]] val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) + sHolder.builder, sHolder.relation, normalizedProjects, filters) + logInfo( s""" - |Pushing operators to ${relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} - |Post-Scan Filters: ${postScanFilters.mkString(",")} |Output: ${output.mkString(", ")} """.stripMargin) val wrappedScan = scan match { case v1: V1Scan => val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) + val pushedFilters = sHolder.builder match { + case f: SupportsPushDownFilters => + f.pushedFilters() + case _ => Array.empty[sources.Filter] + } V1ScanWrapper(v1, translated, pushedFilters) case _ => scan } - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) val projectionOverSchema = ProjectionOverSchema(output.toStructType) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } - val filterCondition = postScanFilters.reduceLeftOption(And) + val filterCondition = filters.reduceLeftOption(And) val newFilterCondition = filterCondition.map(projectionFunc) val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) @@ -83,11 +214,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { } else { withFilter } - withProjection } } +case class ScanBuilderHolder( + output: Seq[AttributeReference], + relation: DataSourceV2Relation, + builder: ScanBuilder) extends LeafNode + // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by // the physical v1 scan node. 
case class V1ScanWrapper( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 78076040e7cf5..5be31c8367b80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -25,14 +25,16 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.metadata.ParquetMetadata import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ @@ -52,7 +54,9 @@ import org.apache.spark.util.SerializableConfiguration * @param dataSchema Schema of Parquet files. * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. + * @param aggSchema Schema of the pushed down aggregation. * @param filters Filters to be pushed down in the batch scan. + * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read.
*/ case class ParquetPartitionReaderFactory( @@ -61,10 +65,17 @@ case class ParquetPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, + aggSchema: StructType, filters: Array[Filter], + aggregation: Option[Aggregation], parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis - private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) + private val newReadDataSchema = if (aggregation.isEmpty) { + readDataSchema + } else { + aggSchema + } + private val resultSchema = StructType(partitionSchema.fields ++ newReadDataSchema.fields) private val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled private val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) @@ -80,6 +91,30 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private def getFooter(file: PartitionedFile): ParquetMetadata = { + val conf = broadcastedConf.value.value + + val filePath = new Path(new URI(file.filePath)) + + if (aggregation.isEmpty) { + ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) + } else { + ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + } + } + + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. + private def isCreatedByParquetMr(file: PartitionedFile): Boolean = + getFooter(file).getFileMetaData.getCreatedBy().startsWith("parquet-mr") + + private def convertTz(isCreatedByParquetMr: Boolean): Option[ZoneId] = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils + .getZoneId(broadcastedConf.value.value.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -87,36 +122,72 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val reader = if (enableVectorizedReader) { - createVectorizedReader(file) + val fileReader = if (aggregation.isEmpty) { + + val reader = if (enableVectorizedReader) { + createVectorizedReader(file) + } else { + createRowBaseReader(file) + } + + new PartitionReader[InternalRow] { + override def next(): Boolean = reader.nextKeyValue() + + override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + + override def close(): Unit = reader.close() + } } else { - createRowBaseReader(file) - } + new PartitionReader[InternalRow] { + var count = 0 - val fileReader = new PartitionReader[InternalRow] { - override def next(): Boolean = reader.nextKeyValue() + override def next(): Boolean = if (count == 0) true else false - override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + override def get(): InternalRow = { + count += 1 + val footer = getFooter(file) + ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, aggregation.get, + aggSchema, datetimeRebaseModeInRead) + } - override def close(): Unit = reader.close() + override def close(): Unit = return + } } - new PartitionReaderWithPartitionValues(fileReader, readDataSchema, + new 
PartitionReaderWithPartitionValues(fileReader, newReadDataSchema, partitionSchema, file.partitionValues) } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val vectorizedReader = createVectorizedReader(file) - vectorizedReader.enableReturningBatches() + val fileReader = if (aggregation.isEmpty) { + val vectorizedReader = createVectorizedReader(file) + vectorizedReader.enableReturningBatches() + + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() + + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() + override def close(): Unit = vectorizedReader.close() + } + } else { + new PartitionReader[ColumnarBatch] { + var count = 0 - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def next(): Boolean = if (count == 0) true else false - override def close(): Unit = vectorizedReader.close() + override def get(): ColumnarBatch = { + count += 1 + val footer = getFooter(file) + ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, aggregation.get, + aggSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead) + } + + override def close(): Unit = return + } } + fileReader } private def buildReaderBase[T]( @@ -131,8 +202,7 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) - lazy val footerFileMetaData = - ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData + lazy val footerFileMetaData = getFooter(file).getFileMetaData // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema @@ -151,16 +221,6 @@ case class ParquetPartitionReaderFactory( // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. - // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
- def isCreatedByParquetMr: Boolean = - footerFileMetaData.getCreatedBy().startsWith("parquet-mr") - - val convertTz = - if (timestampConversion && !isCreatedByParquetMr) { - Some(DateTimeUtils.getZoneId(conf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) - } else { - None - } val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) @@ -181,7 +241,7 @@ case class ParquetPartitionReaderFactory( file.partitionValues, hadoopAttemptContext, pushed, - convertTz, + convertTz(isCreatedByParquetMr(file)), datetimeRebaseMode, int96RebaseMode) reader.initialize(split, hadoopAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 60573ba10ccb6..8e8e8cac6e0a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,6 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} @@ -42,6 +43,8 @@ case class ParquetScan( readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], + pushedAggregations: Option[Aggregation], + pushedDownAggSchema: StructType, options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { @@ -85,28 +88,63 @@ case class ParquetScan( dataSchema, readDataSchema, readPartitionSchema, + pushedDownAggSchema, pushedFilters, + pushedAggregations, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => + val pushedDownAggEqual = if (pushedAggregations.nonEmpty) { + equivalentAggregations(pushedAggregations.get, p.pushedAggregations.get) + } else { + true + } super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val pushedAggregationsStr = if (pushedAggregations.nonEmpty) { + seqToString(pushedAggregations.get.aggregateExpressions) + } else { + "[]" + } + + lazy private val pushedGroupByStr = if (pushedAggregations.nonEmpty) { + seqToString(pushedAggregations.get.groupByColumns) + } else { + "[]" + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> 
pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } override def withFilters( partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + + override def readSchema(): StructType = if (pushedAggregations.nonEmpty) { + pushedDownAggSchema + } else { + StructType(readDataSchema.fields ++ readPartitionSchema.fields) + } + + // Returns whether the two given [[Aggregation]]s are equivalent. + private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 44053830defe5..f2cfff3b2a98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, FieldReference, LiteralValue, Max, Min} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -34,7 +36,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters + with SupportsPushDownAggregates{ lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -70,8 +73,61 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. 
override def pushedFilters(): Array[Filter] = pushedParquetFilters + private var pushedAggregations = Option.empty[Aggregation] + + private var pushedAggregateSchema = new StructType() + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByColumns.nonEmpty || pushedParquetFilters.length > 0) { + return false + } + + aggregation.aggregateExpressions.foreach { + case Max(col, _) => col match { + case ref: FieldReference => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + field.dataType match { + // not push down nested column and Timestamp (INT96 sort order is undefined, parquet + // doesn't return statistics for INT96) + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case _ => pushedAggregateSchema = + pushedAggregateSchema.add(field.copy("max(" + field.name + ")")) + } + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case Min(col, _) => col match { + case ref: FieldReference => + val field = dataSchema.fields(dataSchema.fieldNames.toList.indexOf(ref.fieldNames.head)) + field.dataType match { + // not push down nested column and Timestamp (INT96 sort order is undefined, parquet + // doesn't return statistics for INT96) + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => return false + case _ => pushedAggregateSchema = + pushedAggregateSchema.add(field.copy("min(" + field.name + ")")) + } + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + // not push down distinct count + case Count(col, _, false) => col match { + case _: FieldReference => pushedAggregateSchema = + pushedAggregateSchema.add(StructField("count(" + col + ")", LongType)) + case LiteralValue(1, _) => + pushedAggregateSchema = pushedAggregateSchema.add(StructField("count(*)", LongType)) + case _ => + throw new SparkException("Expression $col is not currently supported.") + } + case _ => return false + } + this.pushedAggregations = Some(aggregation) + true + } + override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + readPartitionSchema(), pushedParquetFilters, pushedAggregations, pushedAggregateSchema, + options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index e9f9bf8df35e6..261774a3e1e49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -38,7 +38,7 @@ case class ParquetTable( extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): ParquetScanBuilder = - new ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 4e7fe8455ff93..36ba853bf19e7 
100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, IsNull, LessThan}
+import org.apache.spark.sql.connector.expressions.Aggregation
 import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitionSpec}
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
@@ -354,7 +355,8 @@ class FileScanSuite extends FileScanSuiteBase {
   val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
     ("ParquetScan",
       (s, fi, ds, rds, rps, f, o, pf, df) =>
-        ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+        ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f,
+          Option.empty[Aggregation], null, o, pf, df),
       Seq.empty),
     ("OrcScan",
       (s, fi, ds, rds, rps, f, o, pf, df) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 9ef43995467c6..955ba3875e850 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.File
+import java.sql.{Date, Timestamp}
 import java.util.concurrent.TimeUnit
 
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -30,8 +31,9 @@ import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol}
 import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement}
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.functions.min
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
@@ -40,7 +42,8 @@ import org.apache.spark.util.Utils
 /**
  * A test suite that tests various Parquet queries.
  */
-abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession {
+abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSparkSession
+  with ExplainSuiteHelper {
   import testImplicits._
 
   test("simple select queries") {
@@ -901,6 +904,302 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
       }
     }
   }
+
+  test("test aggregate push down - nested data shouldn't be pushed down") {
+    val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+    withSQLConf(
+      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+      withParquetTable(data, "t") {
+        val count = sql("SELECT Count(_1) FROM t")
+        count.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [Count(_1,LongType,false)]"
+            checkKeywordsExistsInExplain(count, expected_plan_fragment)
+        }
+        checkAnswer(count, Seq(Row(10)))
+      }
+    }
+  }
+
+  test("test aggregate push down alias") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 6))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+
+        val selectAgg1 = sql("SELECT min(_1) + max(_1) as res FROM t")
+
+        selectAgg1.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [Min(_1,IntegerType), Max(_1,IntegerType)]"
+            checkKeywordsExistsInExplain(selectAgg1, expected_plan_fragment)
+        }
+
+        val selectAgg2 = sql("SELECT min(_1) as minValue, max(_1) as maxValue FROM t")
+        selectAgg2.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [Min(_1,IntegerType), Max(_1,IntegerType)]"
+            checkKeywordsExistsInExplain(selectAgg2, expected_plan_fragment)
+        }
+
+        val df = spark.table("t")
+        val query = df.select($"_1".as("col1")).agg(min($"col1"))
+        query.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []" // aggregate alias not pushed down
+            checkKeywordsExistsInExplain(query, expected_plan_fragment)
+        }
+      }
+    }
+  }
+
+  test("test aggregate push down") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 6))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+
+        // aggregate not pushed down if there is a filter
+        val selectAgg1 = sql("SELECT min(_3) FROM t WHERE _1 > 0")
+        selectAgg1.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedFilters: [IsNotNull(_1), GreaterThan(_1,0)], PushedAggregation: []"
+            checkKeywordsExistsInExplain(selectAgg1, expected_plan_fragment)
+        }
+
+        // This is not pushed down since aggregates have arithmetic operation
+        val selectAgg2 = sql("SELECT min(_3 + _1), max(_3 + _1) FROM t")
+        checkAnswer(selectAgg2, Seq(Row(0, 19)))
+
+        // aggregate not pushed down if one of them can't be pushed down
+        val selectAgg3 = sql("SELECT min(_1), sum(_3) FROM t")
+        selectAgg3.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(selectAgg3, expected_plan_fragment)
+        }
+
+        val selectAgg4 = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," +
+          " count(*), count(_1), count(_2), count(_3) FROM t")
+
+        selectAgg4.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [Min(_3,IntegerType), " +
+                "Min(_3,IntegerType), " +
+                "Max(_3,IntegerType), " +
+                "Min(_1,IntegerType), " +
+                "Max(_1,IntegerType), " +
+                "Max(_1,IntegerType), " +
+                "Count(1,LongType,false), " +
+                "Count(_1,LongType,false), " +
+                "Count(_2,LongType,false), " +
+                "Count(_3,LongType,false)]"
+            checkKeywordsExistsInExplain(selectAgg4, expected_plan_fragment)
+        }
+
+        checkAnswer(selectAgg4, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6)))
+      }
+    }
+    spark.sessionState.catalog.dropTable(
+      TableIdentifier("tmp"), ignoreIfNotExists = true, purge = false)
+  }
+
+  test("aggregate pushdown for different data types") {
+    implicit class StringToDate(s: String) {
+      def date: Date = Date.valueOf(s)
+    }
+
+    implicit class StringToTs(s: String) {
+      def ts: Timestamp = Timestamp.valueOf(s)
+    }
+
+    val rows =
+      Seq(
+        Row(
+          "a string",
+          true,
+          10.toByte,
+          "Spark SQL".getBytes,
+          12.toShort,
+          3,
+          Long.MaxValue,
+          0.15.toFloat,
+          0.75D,
+          Decimal("12.345678"),
+          ("2021-01-01").date,
+          ("2015-01-01 23:50:59.123").ts),
+        Row(
+          "test string",
+          false,
+          1.toByte,
+          "Parquet".getBytes,
+          2.toShort,
+          null,
+          Long.MinValue,
+          0.25.toFloat,
+          0.85D,
+          Decimal("1.2345678"),
+          ("2015-01-01").date,
+          ("2021-01-01 23:50:59.123").ts),
+        Row(
+          null,
+          true,
+          10000.toByte,
+          "Spark ML".getBytes,
+          222.toShort,
+          113,
+          11111111L,
+          0.25.toFloat,
+          0.75D,
+          Decimal("12345.678"),
+          ("2004-06-19").date,
+          ("1999-08-26 10:43:59.123").ts)
+      )
+
+    val schema = StructType(List(StructField("StringCol", StringType, true),
+      StructField("BooleanCol", BooleanType, false),
+      StructField("ByteCol", ByteType, false),
+      StructField("BinaryCol", BinaryType, false),
+      StructField("ShortCol", ShortType, false),
+      StructField("IntegerCol", IntegerType, true),
+      StructField("LongCol", LongType, false),
+      StructField("FloatCol", FloatType, false),
+      StructField("DoubleCol", DoubleType, false),
+      StructField("DecimalCol", DecimalType(25, 5), true),
+      StructField("DateCol", DateType, false),
+      StructField("TimestampCol", TimestampType, false)).toArray)
+
+    val rdd = sparkContext.parallelize(rows)
+    spark.createDataFrame(rdd, schema).createOrReplaceTempView("test")
+    val enableVectorizedReader = Seq("false", "true")
+    for (testVectorizedReader <- enableVectorizedReader) {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+        vectorizedReaderEnabledKey -> testVectorizedReader) {
+        withTempPath { file =>
+
+          val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+            "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+            "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test")
+
+          // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+          // so aggregates are not pushed down
+          testMinWithTS.queryExecution.optimizedPlan.collect {
+            case _: DataSourceV2ScanRelation =>
+              val expected_plan_fragment =
+                "PushedAggregation: []"
+              checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment)
+          }
+
+          checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+            2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+            ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)))
+
+          val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+            "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+            "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test")
+
+          testMinWithOutTS.queryExecution.optimizedPlan.collect {
+            case _: DataSourceV2ScanRelation =>
+              val expected_plan_fragment =
+                "PushedAggregation: [Min(StringCol,StringType), " +
+                  "Min(BooleanCol,BooleanType), " +
+                  "Min(ByteCol,ByteType), " +
+                  "Min(BinaryCol,BinaryType), " +
+                  "Min(ShortCol,ShortType), " +
+                  "Min(IntegerCol,IntegerType), " +
+                  "Min(LongCol,LongType), " +
+                  "Min(FloatCol,FloatType), " +
+                  "Min(DoubleCol,DoubleType), " +
+                  "Min(DecimalCol,DecimalType(25,5)), " +
+                  "Min(DateCol,DateType)]"
+              checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment)
+          }
+
+          checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+            2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+            ("2004-06-19").date)))
+
+          val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+            "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+            "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test")
+
+          // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+          // so aggregates are not pushed down
+          testMaxWithTS.queryExecution.optimizedPlan.collect {
+            case _: DataSourceV2ScanRelation =>
+              val expected_plan_fragment =
+                "PushedAggregation: []"
+              checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment)
+          }
+
+          checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte,
+            "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+            12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)))
+
+          val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+            "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+            "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test")
+
+          testMaxWithoutTS.queryExecution.optimizedPlan.collect {
+            case _: DataSourceV2ScanRelation =>
+              val expected_plan_fragment =
+                "PushedAggregation: [Max(StringCol,StringType), " +
+                  "Max(BooleanCol,BooleanType), " +
+                  "Max(ByteCol,ByteType), " +
+                  "Max(BinaryCol,BinaryType), " +
+                  "Max(ShortCol,ShortType), " +
+                  "Max(IntegerCol,IntegerType), " +
+                  "Max(LongCol,LongType), " +
+                  "Max(FloatCol,FloatType), " +
+                  "Max(DoubleCol,DoubleType), " +
+                  "Max(DecimalCol,DecimalType(25,5)), " +
+                  "Max(DateCol,DateType)]"
+              checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment)
+          }
+
+          checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte,
+            "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+            12345.678, ("2021-01-01").date)))
+
+          val testCount = sql("SELECT count(*), count(StringCol), count(BooleanCol)," +
+            " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," +
+            " count(LongCol), count(FloatCol), count(DoubleCol)," +
+            " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test")
+
+          testCount.queryExecution.optimizedPlan.collect {
+            case _: DataSourceV2ScanRelation =>
+              val expected_plan_fragment =
+                "PushedAggregation: [Count(`1`,LongType,false), " +
+                  "Count(StringCol,LongType,false), " +
+                  "Count(BooleanCol,LongType,false), " +
+                  "Count(ByteCol,LongType,false), " +
+                  "Count(BinaryCol,LongType,false), " +
+                  "Count(ShortCol,LongType,false), " +
+                  "Count(IntegerCol,LongType,false), " +
+                  "Count(LongCol,LongType,false), " +
+                  "Count(FloatCol,LongType,false), " +
+                  "Count(DoubleCol,LongType,false), " +
+                  "Count(DecimalCol,LongType,false), " +
+                  "Count(DateCol,LongType,false), " +
+                  "Count(TimestampCol,LongType,false)]"
+              checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
+          }
+
+          checkAnswer(testCount, Seq(Row(3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)))
+        }
+      }
+    }
+  }
 }
 
 class ParquetV1QuerySuite extends ParquetQuerySuite {