From 11287f610fac162679bd0d5050a7fba9770458dd Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 6 Jan 2020 13:40:21 -0800 Subject: [PATCH 1/8] support partition pruning in file source V2 --- .../apache/spark/sql/v2/avro/AvroScan.scala | 8 ++- .../spark/sql/execution/SparkOptimizer.scala | 2 +- .../PruneFileSourcePartitions.scala | 65 ++++++++++++++----- .../execution/datasources/v2/FileScan.scala | 38 ++++++++--- .../datasources/v2/TextBasedFileScan.scala | 6 +- .../datasources/v2/csv/CSVScan.scala | 12 ++-- .../datasources/v2/json/JsonScan.scala | 12 ++-- .../datasources/v2/orc/OrcScan.scala | 15 +++-- .../datasources/v2/parquet/ParquetScan.scala | 13 ++-- .../datasources/v2/text/TextScan.scala | 11 +++- 10 files changed, 130 insertions(+), 52 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index e1268ac2ce581..f9d358bfe74c9 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan @@ -34,8 +35,8 @@ case class AvroScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - options: CaseInsensitiveStringMap) - extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -49,4 +50,7 @@ case class AvroScan( AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap) } + + override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index e65faefad5b9e..013d94768a2a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -37,7 +37,7 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: PruneFileSourcePartitions :: V2ScanRelationPushDown :: Nil + SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 02d629721327d..d0beab99d65da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,13 +17,40 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan, FileTable} +import org.apache.spark.sql.types.StructType private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { + + private def getPartitionKeyFilters( + sparkSession: SparkSession, + relation: LeafNode, + partitionSchema: StructType, + normalizedFilters: Seq[Expression]): ExpressionSet = { + val partitionColumns = + relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) + val partitionSet = AttributeSet(partitionColumns) + ExpressionSet(normalizedFilters.filter { f => + f.references.subsetOf(partitionSet) + }) + } + + private def rebuildPhysicalOperation( + projects: Seq[NamedExpression], + filters: Seq[Expression], + relation: LeafNode): Project = { + // Keep partition-pruning predicates so that they are visible in physical planning + val filterExpression = filters.reduceLeft(And) + val filter = Filter(filterExpression, relation) + Project(projects, filter) + } + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case op @ PhysicalOperation(projects, filters, logicalRelation @ @@ -41,29 +68,35 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => val normalizedFilters = DataSourceStrategy.normalizeExprs( filters.filterNot(SubqueryExpression.hasSubquery), logicalRelation.output) - - val sparkSession = fsRelation.sparkSession - val partitionColumns = - logicalRelation.resolve( - partitionSchema, sparkSession.sessionState.analyzer.resolver) - val partitionSet = AttributeSet(partitionColumns) - val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f => - f.references.subsetOf(partitionSet) - }) - + val partitionKeyFilters = getPartitionKeyFilters( + fsRelation.sparkSession, logicalRelation, partitionSchema, normalizedFilters) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = - fsRelation.copy(location = prunedFileIndex)(sparkSession) + fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) // Change table stats based on the sizeInBytes of pruned files val withStats = logicalRelation.catalogTable.map(_.copy( stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes))))) val prunedLogicalRelation = logicalRelation.copy( relation = prunedFsRelation, catalogTable = withStats) - // Keep partition-pruning predicates so that they are visible in physical planning - val filterExpression = filters.reduceLeft(And) - val filter = Filter(filterExpression, prunedLogicalRelation) - Project(projects, filter) + rebuildPhysicalOperation(projects, filters, prunedLogicalRelation) + } else { + op + } + + case op @ PhysicalOperation(projects, filters, + v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) + if 
filters.nonEmpty && scan.readDataSchema.nonEmpty => + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(_.deterministic), output) + val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession, + v2Relation, scan.readPartitionSchema, normalizedFilters) + if (partitionKeyFilters.nonEmpty) { + val prunedV2Relation = + v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq)) + val afterScanFilters = + ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) + rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) } else { op } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 55104a2b21deb..06903ed57e277 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics} import org.apache.spark.sql.execution.PartitionedFileUtil @@ -32,13 +33,7 @@ import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -abstract class FileScan( - sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - readDataSchema: StructType, - readPartitionSchema: StructType) - extends Scan - with Batch with SupportsReportStatistics with Logging { +trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging { /** * Returns whether a file with `path` could be split or not. */ @@ -46,6 +41,30 @@ abstract class FileScan( false } + def sparkSession: SparkSession + + def fileIndex: PartitioningAwareFileIndex + + /** + * Returns the required data schema + */ + def readDataSchema: StructType + + /** + * Returns the required partition schema + */ + def readPartitionSchema: StructType + + /** + * Returns the filters that can be use for partition pruning + */ + def partitionFilters: Seq[Expression] + + /** + * Create a new `FileScan` instance from the current one with different `partitionFilters`. + */ + def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan + /** * If a file with `path` is unsplittable, return the unsplittable reason, * otherwise return `None`. 
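The trait members added above are the hook that the new PruneFileSourcePartitions case relies on: the rule keeps only the predicates whose references are a subset of the partition columns and hands them to `withPartitionFilters`, while everything else stays behind as a data filter. The snippet below is a rough illustration of that selection step, not part of the patch: the attribute names (`p1`, `p2`, `value`) are invented, and it assumes a spark-shell session built from this branch so that the Catalyst expression DSL is on the classpath.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, ExpressionSet}

// Hypothetical columns: p1 and p2 are partition columns, value is a data column.
val p1 = 'p1.int
val p2 = 'p2.int
val value = 'value.string
val partitionSet = AttributeSet(Seq(p1, p2))

// Normalized, deterministic, subquery-free filters, as the rule sees them.
val filters = Seq(p1 === 1, p2 === 2, value =!= "a")

// Only predicates that reference nothing but partition columns qualify.
val partitionKeyFilters = ExpressionSet(filters.filter(_.references.subsetOf(partitionSet)))
// partitionKeyFilters contains (p1 = 1) and (p2 = 2); (value != a) stays a data filter
// and ends up in the Filter node rebuilt by rebuildPhysicalOperation.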
@@ -55,11 +74,14 @@ abstract class FileScan( "undefined" } + protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") + override def description(): String = { val locationDesc = fileIndex.getClass.getSimpleName + fileIndex.rootPaths.mkString("[", ", ", "]") val metadata: Map[String, String] = Map( "ReadSchema" -> readDataSchema.catalogString, + "PartitionFilters" -> seqToString(partitionFilters), "Location" -> locationDesc) val metadataStr = metadata.toSeq.sorted.map { case (key, value) => @@ -71,7 +93,7 @@ abstract class FileScan( } protected def partitions: Seq[FilePartition] = { - val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty) + val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty) val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) val partitionAttributes = fileIndex.partitionSchema.toAttributes val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala index 7ddd99a0293b1..1ca3fd42c0597 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala @@ -29,11 +29,7 @@ import org.apache.spark.util.Utils abstract class TextBasedFileScan( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, - readDataSchema: StructType, - readPartitionSchema: StructType, - options: CaseInsensitiveStringMap) - extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { + options: CaseInsensitiveStringMap) extends FileScan { @transient private lazy val codecFactory: CompressionCodecFactory = new CompressionCodecFactory( sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 5125de9313a4c..1b2ddf6257f46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -22,11 +22,11 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource -import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan +import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -37,8 +37,9 @@ case class CSVScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - options: CaseInsensitiveStringMap) - extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) { + options: CaseInsensitiveStringMap, + 
partitionFilters: Seq[Expression] = Seq.empty) + extends TextBasedFileScan(sparkSession, options) { private lazy val parsedOptions: CSVOptions = new CSVOptions( options.asScala.toMap, @@ -87,4 +88,7 @@ case class CSVScan( CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, parsedOptions) } + + override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index a64b78d3c8305..be5af0f03d063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -21,13 +21,13 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} import org.apache.spark.sql.catalyst.json.JSONOptionsInRead import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.json.JsonDataSource -import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan +import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -38,8 +38,9 @@ case class JsonScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - options: CaseInsensitiveStringMap) - extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) { + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression] = Seq.empty) + extends TextBasedFileScan(sparkSession, options) { private val parsedOptions = new JSONOptionsInRead( CaseInsensitiveMap(options.asScala.toMap), @@ -86,4 +87,7 @@ case class JsonScan( JsonPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, parsedOptions) } + + override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 40784516a6f34..6f71e2e7b7eb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -20,6 +20,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan @@ -36,8 +37,8 @@ case class OrcScan( readDataSchema: 
StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, - pushedFilters: Array[Filter]) - extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -52,14 +53,18 @@ case class OrcScan( override def equals(obj: Any): Boolean = obj match { case o: OrcScan => fileIndex == o.fileIndex && dataSchema == o.dataSchema && - readDataSchema == o.readDataSchema && readPartitionSchema == o.readPartitionSchema && - options == o.options && equivalentFilters(pushedFilters, o.pushedFilters) + readDataSchema == o.readDataSchema && readPartitionSchema == o.readPartitionSchema && + options == o.options && equivalentFilters(pushedFilters, o.pushedFilters) && + Set(partitionFilters.map(_.canonicalized)) == Set(o.partitionFilters.map(_.canonicalized)) case _ => false } override def hashCode(): Int = getClass.hashCode() override def description(): String = { - super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") + super.description() + ", PushedFilters: " + seqToString(pushedFilters) } + + override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index cf16a174d9e22..ab0ead5f7fea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetWriteSupport} @@ -39,8 +40,8 @@ case class ParquetScan( readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap) - extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -82,13 +83,17 @@ case class ParquetScan( case p: ParquetScan => fileIndex == p.fileIndex && dataSchema == p.dataSchema && readDataSchema == p.readDataSchema && readPartitionSchema == p.readPartitionSchema && - options == p.options && equivalentFilters(pushedFilters, p.pushedFilters) + options == p.options && equivalentFilters(pushedFilters, p.pushedFilters) && + Set(partitionFilters.map(_.canonicalized)) == Set(p.partitionFilters.map(_.canonicalized)) case _ => false } override def hashCode(): Int = getClass.hashCode() override def description(): String = { - super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") + super.description() + ", PushedFilters: " + seqToString(pushedFilters) } + + override def 
withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index a2c42db59d7fd..25fad8c16b65e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -21,10 +21,11 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.text.TextOptions -import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan +import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -34,8 +35,9 @@ case class TextScan( fileIndex: PartitioningAwareFileIndex, readDataSchema: StructType, readPartitionSchema: StructType, - options: CaseInsensitiveStringMap) - extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) { + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression] = Seq.empty) + extends TextBasedFileScan(sparkSession, options) { private val optionsAsScala = options.asScala.toMap private lazy val textOptions: TextOptions = new TextOptions(optionsAsScala) @@ -67,4 +69,7 @@ case class TextScan( TextPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, readDataSchema, readPartitionSchema, textOptions) } + + override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters) } From acfe4f1d614a8c81f16a011081159a2e1c018821 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 6 Jan 2020 16:24:24 -0800 Subject: [PATCH 2/8] fix compiling error --- .../datasources/PruneFileSourcePartitions.scala | 10 +++++++--- .../spark/sql/execution/datasources/orc/OrcTest.scala | 10 +++++----- .../sql/execution/datasources/orc/OrcFilterSuite.scala | 8 ++++---- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index d0beab99d65da..0edce8b34725f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -46,9 +46,13 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { filters: Seq[Expression], relation: LeafNode): Project = { // Keep partition-pruning predicates so that they are visible in physical planning - val filterExpression = filters.reduceLeft(And) - val filter = Filter(filterExpression, relation) - Project(projects, filter) + val withFilter = if (filters.nonEmpty) { + val filterExpression = filters.reduceLeft(And) + Filter(filterExpression, relation) + } else { + relation + } + Project(projects, withFilter) 
} override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 528c3474a17c5..388744bd0fd6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -119,14 +119,14 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor query.queryExecution.optimizedPlan match { case PhysicalOperation(_, filters, - DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) => + DataSourceV2ScanRelation(_, o: OrcScan, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") if (noneSupported) { - assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") + assert(o.pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") } else { - assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) - assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") + assert(o.pushedFilters.nonEmpty, "No filter is pushed down") + val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) + assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for ${o.pushedFilters}") } case _ => diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index b95a32ef85ddf..d6afa2d24bdfa 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -55,11 +55,11 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { query.queryExecution.optimizedPlan match { case PhysicalOperation(_, filters, - DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) => + DataSourceV2ScanRelation(_, o: OrcScan, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pushedFilters") + assert(o.pushedFilters.nonEmpty, "No filter is pushed down") + val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for ${o.pushedFilters}") checker(maybeFilter.get) case _ => From 51a42a03e2c6fb7991c4b16e5c9a5f377d406105 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 6 Jan 2020 17:20:28 -0800 Subject: [PATCH 3/8] add test case --- .../PruneFileSourcePartitions.scala | 13 +++++----- .../spark/sql/FileBasedDataSourceSuite.scala | 25 ++++++++++++++++++- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 0edce8b34725f..9e1c3eed926ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -32,7 +32,10 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { sparkSession: SparkSession, relation: LeafNode, partitionSchema: StructType, - normalizedFilters: Seq[Expression]): ExpressionSet = { + filters: Seq[Expression], + output: Seq[AttributeReference]): ExpressionSet = { + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output) val partitionColumns = relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) @@ -70,10 +73,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filterNot(SubqueryExpression.hasSubquery), logicalRelation.output) val partitionKeyFilters = getPartitionKeyFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, normalizedFilters) + fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = @@ -91,10 +92,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { case op @ PhysicalOperation(projects, filters, v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) if filters.nonEmpty && scan.readDataSchema.nonEmpty => - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(_.deterministic), output) val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession, - v2Relation, scan.readPartitionSchema, normalizedFilters) + v2Relation, scan.readPartitionSchema, filters, output) if (partitionKeyFilters.nonEmpty) { val prunedV2Relation = v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index d4f76858af95f..0a19b769c6638 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, FileScan} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -726,6 +727,28 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSparkSession { } } + test("File source v2: support partition pruning") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + allFileBasedDataSources.foreach { format => + withTempPath { dir => + Seq(("a", 1), ("b", 2)).toDF("v", "p").write.format(format) + .partitionBy("p").save(dir.getCanonicalPath) + val 
df = spark.read.format(format).load(dir.getCanonicalPath).where("p = 1") + val fileScan = df.queryExecution.executedPlan collectFirst { + case BatchScanExec(_, f: FileScan) => f + } + assert(fileScan.nonEmpty) + assert(fileScan.get.partitionFilters.nonEmpty) + assert(fileScan.get.planInputPartitions().forall { partition => + partition.asInstanceOf[FilePartition].files.forall { file => + file.filePath.contains("p=1") + } + }) + } + } + } + } + test("File table location should include both values of option `path` and `paths`") { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { withTempPaths(3) { paths => From 65200b6ffc823b8848111a00e2d43f82b9a7604b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 7 Jan 2020 11:10:03 -0800 Subject: [PATCH 4/8] fix sameResult method --- .../apache/spark/sql/v2/avro/AvroScan.scala | 34 ++++++++++++------- .../execution/datasources/v2/FileScan.scala | 12 ++++++- .../datasources/v2/csv/CSVScan.scala | 8 +++++ .../datasources/v2/json/JsonScan.scala | 8 +++++ .../datasources/v2/orc/OrcScan.scala | 7 ++-- .../datasources/v2/parquet/ParquetScan.scala | 6 ++-- .../datasources/v2/text/TextScan.scala | 8 +++++ 7 files changed, 61 insertions(+), 22 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index f9d358bfe74c9..bb840e69d99a3 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -37,20 +37,28 @@ case class AvroScan( readPartitionSchema: StructType, options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty) extends FileScan { - override def isSplitable(path: Path): Boolean = true - - override def createReaderFactory(): PartitionReaderFactory = { - val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap - // Hadoop Configurations are case sensitive. - val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) - val broadcastedConf = sparkSession.sparkContext.broadcast( - new SerializableConfiguration(hadoopConf)) - // The partition values are already truncated in `FileScan.partitions`. - // We should use `readPartitionSchema` as the partition schema here. - AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap) - } + override def isSplitable(path: Path): Boolean = true + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = sparkSession.sparkContext.broadcast( + new SerializableConfiguration(hadoopConf)) + // The partition values are already truncated in `FileScan.partitions`. + // We should use `readPartitionSchema` as the partition schema here. 
+ AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, + dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap) + } override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters) + + override def equals(obj: Any): Boolean = obj match { + case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options + + case _ => false } + + override def hashCode(): Int = super.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 06903ed57e277..a22e1ccfe4515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics} import org.apache.spark.sql.execution.PartitionedFileUtil @@ -76,6 +76,16 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") + override def equals(obj: Any): Boolean = obj match { + case f: FileScan => + fileIndex == f.fileIndex && readSchema == f.readSchema + ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) + + case _ => false + } + + override def hashCode(): Int = getClass.hashCode() + override def description(): String = { val locationDesc = fileIndex.getClass.getSimpleName + fileIndex.rootPaths.mkString("[", ", ", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 1b2ddf6257f46..78b04aa811e09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -91,4 +91,12 @@ case class CSVScan( override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters) + + override def equals(obj: Any): Boolean = obj match { + case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options + + case _ => false + } + + override def hashCode(): Int = super.hashCode() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index be5af0f03d063..153b402476c40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -90,4 +90,12 @@ case class JsonScan( override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters) + + 
override def equals(obj: Any): Boolean = obj match { + case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options + + case _ => false + } + + override def hashCode(): Int = super.hashCode() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 6f71e2e7b7eb0..f0595cb6d09c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -52,10 +52,9 @@ case class OrcScan( override def equals(obj: Any): Boolean = obj match { case o: OrcScan => - fileIndex == o.fileIndex && dataSchema == o.dataSchema && - readDataSchema == o.readDataSchema && readPartitionSchema == o.readPartitionSchema && - options == o.options && equivalentFilters(pushedFilters, o.pushedFilters) && - Set(partitionFilters.map(_.canonicalized)) == Set(o.partitionFilters.map(_.canonicalized)) + super.equals(o) && dataSchema == o.dataSchema && options == o.options && + equivalentFilters(pushedFilters, o.pushedFilters) + case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index ab0ead5f7fea9..44179e2e42a4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -81,10 +81,8 @@ case class ParquetScan( override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => - fileIndex == p.fileIndex && dataSchema == p.dataSchema && - readDataSchema == p.readDataSchema && readPartitionSchema == p.readPartitionSchema && - options == p.options && equivalentFilters(pushedFilters, p.pushedFilters) && - Set(partitionFilters.map(_.canonicalized)) == Set(p.partitionFilters.map(_.canonicalized)) + super.equals(p) && dataSchema == p.dataSchema && options == p.options && + equivalentFilters(pushedFilters, p.pushedFilters) case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index 25fad8c16b65e..cf6595e5c126c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -72,4 +72,12 @@ case class TextScan( override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters) + + override def equals(obj: Any): Boolean = obj match { + case t: TextScan => super.equals(t) && options == t.options + + case _ => false + } + + override def hashCode(): Int = super.hashCode() } From 13e95359b30f6e25ba1ed5253a844de28608b1e7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 7 Jan 2020 13:29:35 -0800 Subject: [PATCH 5/8] update tests --- .../org/apache/spark/sql/avro/AvroSuite.scala | 62 ++++++++++++++++++- .../spark/sql/FileBasedDataSourceSuite.scala | 19 ++++-- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala 
index dc60cfe41ca7a..f9e3c20265c0a 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -36,10 +36,13 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT} -import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.sql.v2.avro.AvroScan import org.apache.spark.util.Utils abstract class AvroSuite extends QueryTest with SharedSparkSession { @@ -1502,8 +1505,65 @@ class AvroV1Suite extends AvroSuite { } class AvroV2Suite extends AvroSuite { + import testImplicits._ + override protected def sparkConf: SparkConf = super .sparkConf .set(SQLConf.USE_V1_SOURCE_LIST, "") + + test("Avro source v2: support partition pruning") { + withTempPath { dir => + Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1)) + .toDF("value", "p1", "p2") + .write + .format("avro") + .partitionBy("p1", "p2") + .option("header", true) + .save(dir.getCanonicalPath) + val df = spark + .read + .format("avro") + .option("header", true) + .load(dir.getCanonicalPath) + .where("p1 = 1 and p2 = 2 and value != \"a\"") + val fileScan = df.queryExecution.executedPlan collectFirst { + case BatchScanExec(_, f: AvroScan) => f + } + assert(fileScan.nonEmpty) + assert(fileScan.get.partitionFilters.nonEmpty) + assert(fileScan.get.planInputPartitions().forall { partition => + partition.asInstanceOf[FilePartition].files.forall { file => + file.filePath.contains("p1=1") && file.filePath.contains("p2=2") + } + }) + checkAnswer(df, Row("b", 1, 2)) + } + } + + private def getBatchScanExec(plan: SparkPlan): BatchScanExec = { + plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec] + } + + test("Avro source v2: same result with different orders of data filters and partition filters") { + withTempPath { path => + val tmpDir = path.getCanonicalPath + spark + .range(10) + .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d") + .write + .partitionBy("a", "b") + .format("avro") + .save(tmpDir) + val df = spark.read.format("avro").load(tmpDir) + // partition filters: a > 1 AND b < 9 + // data filters: c > 1 AND d < 9 + val plan1 = df.where("a > 1 AND b < 9 AND c > 1 AND d < 9").queryExecution.sparkPlan + val plan2 = df.where("b < 9 AND a > 1 AND d < 9 AND c > 1").queryExecution.sparkPlan + assert(plan1.sameResult(plan2)) + val scan1 = getBatchScanExec(plan1) + val scan2 = getBatchScanExec(plan2) + assert(scan1.sameResult(scan2)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 0a19b769c6638..64a08343e7cfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -731,9 +731,19 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSparkSession { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { allFileBasedDataSources.foreach { format => withTempPath { dir => - Seq(("a", 1), 
("b", 2)).toDF("v", "p").write.format(format) - .partitionBy("p").save(dir.getCanonicalPath) - val df = spark.read.format(format).load(dir.getCanonicalPath).where("p = 1") + Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1)) + .toDF("value", "p1", "p2") + .write + .format(format) + .partitionBy("p1", "p2") + .option("header", true) + .save(dir.getCanonicalPath) + val df = spark + .read + .format(format) + .option("header", true) + .load(dir.getCanonicalPath) + .where("p1 = 1 and p2 = 2 and value != \"a\"") val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan) => f } @@ -741,9 +751,10 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSparkSession { assert(fileScan.get.partitionFilters.nonEmpty) assert(fileScan.get.planInputPartitions().forall { partition => partition.asInstanceOf[FilePartition].files.forall { file => - file.filePath.contains("p=1") + file.filePath.contains("p1=1") && file.filePath.contains("p2=2") } }) + checkAnswer(df, Row("b", 1, 2)) } } } From 31c8c14810d6cbe7fbe9166909b2a7f5b92b6dd7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 7 Jan 2020 14:22:40 -0800 Subject: [PATCH 6/8] revise comments --- .../execution/datasources/PruneFileSourcePartitions.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 9e1c3eed926ae..fdf4967e0339b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -48,7 +48,6 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { projects: Seq[NamedExpression], filters: Seq[Expression], relation: LeafNode): Project = { - // Keep partition-pruning predicates so that they are visible in physical planning val withFilter = if (filters.nonEmpty) { val filterExpression = filters.reduceLeft(And) Filter(filterExpression, relation) @@ -84,6 +83,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes))))) val prunedLogicalRelation = logicalRelation.copy( relation = prunedFsRelation, catalogTable = withStats) + // Keep partition-pruning predicates so that they are visible in physical planning rebuildPhysicalOperation(projects, filters, prunedLogicalRelation) } else { op @@ -91,12 +91,13 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { case op @ PhysicalOperation(projects, filters, v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) - if filters.nonEmpty && scan.readDataSchema.nonEmpty => + if filters.nonEmpty && scan.readDataSchema.nonEmpty => val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession, v2Relation, scan.readPartitionSchema, filters, output) if (partitionKeyFilters.nonEmpty) { val prunedV2Relation = v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq)) + // The pushed down partition filters don't need to be evaluated again. 
val afterScanFilters = ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) From 7652c351f0abf4fabfed21a0173a4e4325edf2b4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 7 Jan 2020 14:55:33 -0800 Subject: [PATCH 7/8] check filters in test cases --- .../scala/org/apache/spark/sql/avro/AvroSuite.scala | 12 ++++++++++++ .../datasources/PruneFileSourcePartitions.scala | 2 +- .../apache/spark/sql/FileBasedDataSourceSuite.scala | 12 ++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index f9e3c20265c0a..3f2744014c199 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -36,6 +36,8 @@ import org.apache.commons.io.FileUtils import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -1527,6 +1529,16 @@ class AvroV2Suite extends AvroSuite { .option("header", true) .load(dir.getCanonicalPath) .where("p1 = 1 and p2 = 2 and value != \"a\"") + + val filterCondition = df.queryExecution.optimizedPlan.collectFirst { + case f: Filter => f.condition + } + assert(filterCondition.isDefined) + // The partitions filters should be pushed down and no need to be reevaluated. + assert(filterCondition.get.collectFirst { + case a: AttributeReference if a.name == "p1" || a.name == "p2" => a + }.isEmpty) + val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: AvroScan) => f } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index fdf4967e0339b..7fd154ccac445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -97,7 +97,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { if (partitionKeyFilters.nonEmpty) { val prunedV2Relation = v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq)) - // The pushed down partition filters don't need to be evaluated again. + // The pushed down partition filters don't need to be reevaluated. 
val afterScanFilters = ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 64a08343e7cfc..b8b27b52c67f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -28,7 +28,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, FileScan} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable @@ -744,6 +746,16 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSparkSession { .option("header", true) .load(dir.getCanonicalPath) .where("p1 = 1 and p2 = 2 and value != \"a\"") + + val filterCondition = df.queryExecution.optimizedPlan.collectFirst { + case f: Filter => f.condition + } + assert(filterCondition.isDefined) + // The partitions filters should be pushed down and no need to be reevaluated. + assert(filterCondition.get.collectFirst { + case a: AttributeReference if a.name == "p1" || a.name == "p2" => a + }.isEmpty) + val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan) => f } From 58a4a07c50fd0ef54ca55fe1883d228eeb8464b2 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 7 Jan 2020 16:34:22 -0800 Subject: [PATCH 8/8] fix v1.2 compiling --- .../sql/execution/datasources/orc/OrcFilterSuite.scala | 9 ++++----- .../sql/execution/datasources/orc/OrcFilterSuite.scala | 3 +-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index d09236a934337..526ce5cb70856 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -53,12 +53,11 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { .where(Column(predicate)) query.queryExecution.optimizedPlan match { - case PhysicalOperation(_, filters, - DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) => + case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pushedFilters") + assert(o.pushedFilters.nonEmpty, "No filter is pushed down") + val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for 
${o.pushedFilters}") checker(maybeFilter.get) case _ => diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index d6afa2d24bdfa..f88fec7ed4d65 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -54,8 +54,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { .where(Column(predicate)) query.queryExecution.optimizedPlan match { - case PhysicalOperation(_, filters, - DataSourceV2ScanRelation(_, o: OrcScan, _)) => + case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") assert(o.pushedFilters.nonEmpty, "No filter is pushed down") val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters)