diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index f653bf41c1624..96e8e7852529b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -610,7 +610,9 @@ case class HiveTableRelation(
     tableMeta: CatalogTable,
     dataCols: Seq[AttributeReference],
     partitionCols: Seq[AttributeReference],
-    tableStats: Option[Statistics] = None) extends LeafNode with MultiInstanceRelation {
+    tableStats: Option[Statistics] = None,
+    @transient prunedPartitions: Option[Seq[CatalogTablePartition]] = None)
+  extends LeafNode with MultiInstanceRelation {
   assert(tableMeta.identifier.database.isDefined)
   assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType))
   assert(tableMeta.dataSchema.sameType(dataCols.toStructType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index 5ff33b9cfbfc9..e8ce7aef75ff1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -84,7 +84,7 @@ object TPCDSQueryBenchmark extends SqlBasedBenchmark {
           queryRelations.add(alias.identifier)
         case LogicalRelation(_, _, Some(catalogTable), _) =>
           queryRelations.add(catalogTable.identifier.table)
-        case HiveTableRelation(tableMeta, _, _, _) =>
+        case HiveTableRelation(tableMeta, _, _, _, _) =>
           queryRelations.add(tableMeta.identifier.table)
         case _ =>
       }
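A side note on the TPCDSQueryBenchmark hunk: adding a constructor field to the HiveTableRelation case class changes the arity of its compiler-generated unapply, which is why the positional match above gains a fifth wildcard. A minimal sketch, reusing the benchmark's own names, of a type-based match that would not need touching when further fields are added:

    // Binding the whole relation instead of destructuring it keeps the
    // clause stable if HiveTableRelation grows more fields later.
    case relation: HiveTableRelation =>
      queryRelations.add(relation.tableMeta.identifier.table)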
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 3df77fec20993..1d71f47708997 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -21,9 +21,10 @@ import org.apache.spark.annotation.Unstable
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, ResolveSessionCatalog}
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.v2.TableCapabilityCheck
@@ -93,6 +94,20 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
     customCheckRules
   }
 
+  /**
+   * Logical query plan optimizer that takes into account Hive-specific rules.
+   */
+  override protected def optimizer: Optimizer = {
+    new SparkOptimizer(catalogManager, catalog, experimentalMethods) {
+      override def postHocOptimizationBatches: Seq[Batch] = Seq(
+        Batch("Prune Hive Table Partitions", Once, PruneHiveTablePartitions(session))
+      )
+
+      override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
+        super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
+    }
+  }
+
   /**
    * Planner that takes into account Hive-specific strategies.
    */
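The optimizer override registers PruneHiveTablePartitions in a post-hoc batch, so it runs once, after the main operator-optimization batches have finished pushing filters down to the relation. For contrast, a hedged sketch of the SparkSessionExtensions route, which injects rules into the main operator-optimization batch and would therefore run before pushdown has settled:

    // Sketch only: extension-injected rules join the fixed-point operator
    // batch rather than a post-hoc batch, hence the SparkOptimizer subclass.
    val spark = SparkSession.builder()
      .enableHiveSupport()
      .withExtensions(_.injectOptimizerRule(session => PruneHiveTablePartitions(session)))
      .getOrCreate()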
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 33ca1889e944d..cf806cda4119e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -21,15 +21,16 @@ import java.io.IOException
 import java.util.Locale
 
 import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.common.StatsSetupConst
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, ScriptTransformation, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, ScriptTransformation, Statistics}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
+import org.apache.spark.sql.execution.command.{CommandUtils, CreateTableCommand, DDLUtils}
 import org.apache.spark.sql.execution.datasources.CreateTable
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
@@ -231,6 +232,68 @@ case class RelationConversions(
   }
 }
 
+/**
+ * TODO: merge this with PruneFileSourcePartitions after we completely make Hive a data source.
+ */
+case class PruneHiveTablePartitions(
+    session: SparkSession) extends Rule[LogicalPlan] with PredicateHelper {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case op @ PhysicalOperation(projections, predicates, relation: HiveTableRelation)
+        if predicates.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
+      val normalizedFilters = predicates.map { e =>
+        e transform {
+          case a: AttributeReference =>
+            a.withName(relation.output.find(_.semanticEquals(a)).get.name)
+        }
+      }
+      val partitionSet = AttributeSet(relation.partitionCols)
+      val pruningPredicates = normalizedFilters.filter { predicate =>
+        predicate.references.nonEmpty && predicate.references.subsetOf(partitionSet)
+      }
+      // SPARK-24085: scalar subqueries should be skipped for partition pruning
+      val hasScalarSubquery = pruningPredicates.exists(SubqueryExpression.hasSubquery)
+      val conf = session.sessionState.conf
+      if (conf.metastorePartitionPruning && pruningPredicates.nonEmpty && !hasScalarSubquery) {
+        val prunedPartitions = session.sharedState.externalCatalog.listPartitionsByFilter(
+          relation.tableMeta.database,
+          relation.tableMeta.identifier.table,
+          pruningPredicates,
+          conf.sessionLocalTimeZone)
+        val sizeInBytes = try {
+          val sizeOfPartitions = prunedPartitions.map { part =>
+            val rawDataSize = part.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
+            val totalSize = part.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
+            if (rawDataSize.isDefined && rawDataSize.get > 0) {
+              rawDataSize.get
+            } else if (totalSize.isDefined && totalSize.get > 0L) {
+              totalSize.get
+            } else if (conf.fallBackToHdfsForStatsEnabled) {
+              CommandUtils.calculateLocationSize(
+                session.sessionState, relation.tableMeta.identifier, part.storage.locationUri)
+            } else { // We cannot get any size statistics here. Use 0 as the default size to sum up.
+              0L
+            }
+          }.sum
+          // If the total size of the partitions is zero, fall back to the default size.
+          if (sizeOfPartitions == 0L) conf.defaultSizeInBytes else sizeOfPartitions
+        } catch {
+          case e: IOException =>
+            logWarning("Failed to get table size from HDFS.", e)
+            conf.defaultSizeInBytes
+        }
+        val withStats = relation.tableMeta.copy(
+          stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes))))
+        val prunedHiveTableRelation =
+          relation.copy(tableMeta = withStats, prunedPartitions = Some(prunedPartitions))
+        val filterExpression = predicates.reduceLeft(And)
+        val filter = Filter(filterExpression, prunedHiveTableRelation)
+        Project(projections, filter)
+      } else {
+        op
+      }
+  }
+}
+
 private[hive] trait HiveStrategies {
   // Possibly being too clever with types here... or not clever enough.
   self: SparkPlanner =>
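The per-partition size estimate in PruneHiveTablePartitions falls back through three sources: Hive's rawDataSize parameter, then totalSize, then a filesystem scan when spark.sql.statistics.fallBackToHdfs is enabled. A self-contained sketch of that chain, where estimatePartitionSize is a hypothetical helper and hdfsSize stands in for the CommandUtils.calculateLocationSize call:

    import org.apache.hadoop.hive.common.StatsSetupConst

    // Returns the first positive size statistic, the HDFS size if the
    // fallback is enabled, or 0 so the caller can sum across partitions.
    def estimatePartitionSize(
        parameters: Map[String, String],
        fallBackToHdfs: Boolean,
        hdfsSize: => Long): Long = {
      val rawDataSize = parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
      val totalSize = parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
      rawDataSize.filter(_ > 0)
        .orElse(totalSize.filter(_ > 0))
        .getOrElse(if (fallBackToHdfs) hdfsSize else 0L)
    }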
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 5b00e2ebafa43..d353d5c546ba4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -166,14 +166,14 @@ case class HiveTableScanExec(
   @transient lazy val rawPartitions = {
     val prunedPartitions =
       if (sparkSession.sessionState.conf.metastorePartitionPruning &&
-          partitionPruningPred.size > 0) {
+          partitionPruningPred.nonEmpty) {
         // Retrieve the original attributes based on expression ID so that capitalization matches.
         val normalizedFilters = partitionPruningPred.map(_.transform {
           case a: AttributeReference => originalAttributes(a)
         })
-        sparkSession.sessionState.catalog.listPartitionsByFilter(
-          relation.tableMeta.identifier,
-          normalizedFilters)
+        relation.prunedPartitions.getOrElse(
+          sparkSession.sessionState.catalog.listPartitionsByFilter(
+            relation.tableMeta.identifier, normalizedFilters))
       } else {
         sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier)
       }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 40581066c62bb..8b483441970f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -1514,4 +1514,35 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       }
     }
   }
+
+  test("Broadcast join can be inferred when a partitioned table is pruned under the threshold") {
+    withTempView("tempTbl", "largeTbl") {
+      withTable("partTbl") {
+        spark.range(0, 1000, 1, 2).selectExpr("id as col1", "id as col2")
+          .createOrReplaceTempView("tempTbl")
+        spark.range(0, 100000, 1, 2).selectExpr("id as col1", "id as col2")
+          .createOrReplaceTempView("largeTbl")
+        sql("CREATE TABLE partTbl (col1 INT, col2 STRING) " +
+          "PARTITIONED BY (part1 STRING, part2 INT) STORED AS textfile")
+        for (part1 <- Seq("a", "b", "c", "d"); part2 <- Seq(1, 2)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE partTbl PARTITION (part1='$part1',part2='$part2')
+               |SELECT col1, col2 FROM tempTbl
+             """.stripMargin)
+        }
+        val query = "SELECT * FROM largeTbl JOIN partTbl ON (largeTbl.col1 = partTbl.col1 " +
+          "AND partTbl.part1 = 'a' AND partTbl.part2 = 1)"
+        Seq(true, false).foreach { partitionPruning =>
+          withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true",
+            SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "8001",
+            SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> s"$partitionPruning") {
+            val broadcastJoins =
+              sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
+            assert(broadcastJoins.nonEmpty === partitionPruning)
+          }
+        }
+      }
+    }
+  }
 }
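On the thresholds in this test: with the HDFS-stats fallback enabled, the pruned relation's sizeInBytes is the on-disk size of the single (part1='a', part2=1) partition, which the 8001-byte autoBroadcastJoinThreshold is evidently chosen to sit just above, while the unpruned estimate covers all eight partitions and stays well over it. A hedged way to inspect the estimate that drives the broadcast decision (query as defined in the test):

    // Collect the optimizer's size estimate for the Hive relation.
    sql(query).queryExecution.optimizedPlan.collect {
      case r: HiveTableRelation => r.stats.sizeInBytes
    }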
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 3f9bb8de42e09..e97ee83bf7669 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -128,8 +128,12 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
         // If the pruning predicate is used, getHiveQlPartitions should only return the
         // qualified partition; otherwise, it returns all the partitions.
         val expectedNumPartitions = if (hivePruning == "true") 1 else 2
-        checkNumScannedPartitions(
-          stmt = s"SELECT id, p2 FROM $table WHERE p2 <= 'b'", expectedNumPartitions)
+        val stmt = s"SELECT id, p2 FROM $table WHERE p2 <= 'b'"
+        checkNumScannedPartitions(stmt = stmt, expectedNumPartitions)
+        // The pruned partitions are kept in HiveTableRelation.
+        val prunedNumPartitions = if (hivePruning == "true") 1 else 0
+        assert(
+          getHiveTableScanExec(stmt).relation.prunedPartitions.size === prunedNumPartitions)
       }
     }
 
@@ -137,8 +141,10 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
      withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> hivePruning) {
        // If the pruning predicate does not exist, getHiveQlPartitions should always
        // return all the partitions.
-        checkNumScannedPartitions(
-          stmt = s"SELECT id, p2 FROM $table WHERE id <= 3", expectedNumParts = 2)
+        val stmt = s"SELECT id, p2 FROM $table WHERE id <= 3"
+        checkNumScannedPartitions(stmt = stmt, expectedNumParts = 2)
+        // No pruning is triggered, so no partitions are kept in HiveTableRelation.
+        assert(getHiveTableScanExec(stmt).relation.prunedPartitions.isEmpty)
      }
    }
  }
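End to end, the rule stores the pruned partitions on HiveTableRelation during optimization and HiveTableScanExec.rawPartitions reuses them, so the listPartitionsByFilter call made during optimization is not repeated at execution. A hypothetical spark-shell probe of that plumbing, using the partTbl created in the StatisticsSuite test:

    // After optimization the relation carries the pruned partition list.
    sql("SELECT * FROM partTbl WHERE part1 = 'a' AND part2 = 1")
      .queryExecution.optimizedPlan.collect {
        case r: HiveTableRelation => r.prunedPartitions.map(_.size)
      }
    // With pruning enabled this yields List(Some(1)); otherwise List(None).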