Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetInputFormat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InSet}
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
Expand All @@ -46,14 +48,22 @@ case class ParquetScan(
pushedFilters: Array[Filter],
options: CaseInsensitiveStringMap,
pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
originPartitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty)
extends FileScan
with SupportsRuntimeV2Filtering {
// A scan with a pushed-down aggregate reads only the file footer once, so
// such a file must be handled by a single task and must not be split.
override def isSplitable(path: Path): Boolean = pushedAggregate.isEmpty

// Partition filters received at runtime via `filter` (dynamic partition
// pruning). Mutable because SupportsRuntimeV2Filtering pushes predicates
// into the scan after it has been constructed.
private var dppPartitionFilters: Seq[Expression] = Seq.empty

// The effective partition filters are the planning-time filters combined
// with any filters collected at runtime through DPP.
override def partitionFilters: Seq[Expression] = {
  originPartitionFilters ++ dppPartitionFilters
}

override def readSchema(): StructType = {
// If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder`
// and no need to call super.readSchema()
Expand Down Expand Up @@ -132,4 +142,29 @@ case class ParquetScan(
Map("PushedAggregation" -> pushedAggregationsStr) ++
Map("PushedGroupBy" -> pushedGroupByStr)
}

/**
 * Returns the partition columns that also appear in the read schema; these
 * are the attributes that runtime (DPP) predicates may reference.
 */
override def filterAttributes(): Array[NamedReference] = {
  val readableNames = readSchema.fields.map(_.name).toSet
  for {
    field <- readPartitionSchema.fields
    if readableNames.contains(field.name)
  } yield FieldReference(field.name)
}

/**
 * Receives runtime predicates from dynamic partition pruning and converts
 * supported ones into catalyst partition filters (accumulated in
 * `dppPartitionFilters`).
 *
 * Only predicates of the shape `col IN (lit, lit, ...)` are translated.
 * All other predicates are ignored, which is safe: runtime filtering is an
 * optimization, and skipping a predicate only forgoes pruning — it never
 * changes query results.
 */
override def filter(predicates: Array[Predicate]): Unit = {
  predicates.foreach {
    case p: Predicate if p.name().equals("IN") =>
      val children = p.children()
      // Expect the first child to be the column reference and every
      // remaining child to be a literal value.
      if (children.length > 1 && children(0).isInstanceOf[FieldReference]
        && children.tail.forall(_.isInstanceOf[LiteralValue[_]])) {
        val filterRef = children(0).asInstanceOf[FieldReference].references.head
        val literals = children.tail.map(_.asInstanceOf[LiteralValue[_]])
        dppPartitionFilters = dppPartitionFilters :+ InSet(
          AttributeReference(filterRef.toString, literals.head.dataType)(),
          literals.map(_.value).toSet[Any])
      }
    // BUG FIX: without this catch-all, any non-IN predicate pushed at
    // runtime would throw scala.MatchError from the partial function.
    case _ =>
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,20 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
* Writes `data` to a data source file, which is then passed to `f` and will be deleted after `f`
* returns.
*/
protected def withDataSourceFile[T <: Product : ClassTag : TypeTag]
(data: Seq[T])
(f: String => Unit): Unit = {
/**
 * Writes `data` to a temporary data source file — partitioned by
 * `partitionNames` when non-empty — passes the path to `f`, and deletes the
 * file afterwards.
 */
protected def withDataSourceFile[T <: Product: ClassTag: TypeTag](
    data: Seq[T],
    partitionNames: Seq[String] = Seq.empty)(f: String => Unit): Unit = {
  withTempPath { file =>
    val writer = spark.createDataFrame(data).write.format(dataSourceName)
    // Only apply partitionBy when partition columns were requested.
    val configured =
      if (partitionNames.isEmpty) writer else writer.partitionBy(partitionNames: _*)
    configured.save(file.getCanonicalPath)
    f(file.getCanonicalPath)
  }
}
Expand All @@ -80,21 +89,24 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
* Writes `data` to a data source file and reads it back as a [[DataFrame]],
* which is then passed to `f`. The file will be deleted after `f` returns.
*/
protected def withDataSourceDataFrame[T <: Product : ClassTag : TypeTag]
(data: Seq[T], testVectorized: Boolean = true)
(f: DataFrame => Unit): Unit = {
withDataSourceFile(data)(path => readFile(path.toString, testVectorized)(f))
/**
 * Writes `data` (optionally partitioned by `partitionNames`), reads it back
 * as a DataFrame, and passes it to `f`; the file is removed when `f` returns.
 */
protected def withDataSourceDataFrame[T <: Product: ClassTag: TypeTag](
    data: Seq[T],
    testVectorized: Boolean = true,
    partitionNames: Seq[String] = Seq.empty)(f: DataFrame => Unit): Unit = {
  withDataSourceFile(data, partitionNames) { path =>
    readFile(path.toString, testVectorized)(f)
  }
}

/**
* Writes `data` to a data source file, reads it back as a [[DataFrame]] and registers it as a
* temporary table named `tableName`, then call `f`. The temporary table together with the
* data file will be dropped/deleted after `f` returns.
*/
protected def withDataSourceTable[T <: Product : ClassTag : TypeTag]
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
withDataSourceDataFrame(data, testVectorized) { df =>
protected def withDataSourceTable[T <: Product: ClassTag: TypeTag](
data: Seq[T],
tableName: String,
testVectorized: Boolean = true,
partitionNames: Seq[String] = Seq.empty)(f: => Unit): Unit = {
withDataSourceDataFrame(data, testVectorized, partitionNames) { df =>
df.createOrReplaceTempView(tableName)
withTempView(tableName)(f)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.{ExplainMode, FileSourceScanExec}
import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol}
import org.apache.spark.sql.execution.datasources.parquet.TestingUDT._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
Expand Down Expand Up @@ -1180,6 +1180,72 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
checkAnswer(sql("select * from tbl"), expected)
}
}

test("dynamic partition pruning") {
// Fact table: 100 rows; written partitioned by _1, _2, _3 below.
// (Previous comment about "4 rows ... null" was a copy-paste error: no
// cell here is null.)
val fact = (0 to 99).map { i =>
(i, i + 1, (i + 2).toByte, (i + 3).toShort, (i * 20) % 100, (i + 1).toString)
}

// Dim table: 10 rows, unpartitioned; _5 = i * 10 drives the WHERE filters.
val dim = (0 to 9).map { i =>
(i, i + 1, (i + 2).toByte, (i + 3).toShort, (i * 10), (i + 1).toString)
}

withParquetTable(fact, "fact", true, Seq.apply("_1", "_2", "_3")) {
withParquetTable(dim, "dim") {
// Join on Int partition column _2; only dim row i=9 passes _5 > 80.
val df = sql("""
|SELECT f._1, f._2, f._3, f._4 FROM fact f
|JOIN dim d
|ON (f._2 = d._2)
|WHERE d._5 > 80
""".stripMargin)
val explainDF = df.queryExecution.explainString(
ExplainMode
.fromString("extended"))
assert(explainDF.contains("dynamicpruningexpression"))
checkAnswer(df, Row(9, 10, 11, 12) :: Nil)

// reuse a single Byte key (partition column _3)
val dfByte = sql("""
|SELECT f._1, f._2, f._3, f._4 FROM fact f
|JOIN dim d
|ON (f._3 = d._3)
|WHERE d._5 > 80
""".stripMargin)
val explainDFByte = dfByte.queryExecution.explainString(
ExplainMode
.fromString("extended"))
assert(explainDFByte.contains("dynamicpruningexpression"))
checkAnswer(dfByte, Row(9, 10, 11, 12) :: Nil)

// NOTE(review): labelled "String key" but the query below joins on _3
// (a Byte column), identical to the previous query — a string-typed
// join key is never actually exercised by this test. To cover strings,
// the fact table would need a string partition column.
val dfStr = sql("""
|SELECT f._1, f._2, f._3, f._4 FROM fact f
|JOIN dim d
|ON (f._3 = d._3)
|WHERE d._5 > 80
""".stripMargin)
val explainDFStr = dfStr.queryExecution.explainString(
ExplainMode
.fromString("extended"))
assert(explainDFStr.contains("dynamicpruningexpression"))
checkAnswer(dfStr, Row(9, 10, 11, 12) :: Nil)

// Wider dim filter (_5 > 70): two dim rows survive, so two fact
// partitions remain after pruning.
val dfMultStr = sql("""
|SELECT f._1, f._2, f._3, f._4 FROM fact f
|JOIN dim d
|ON (f._3 = d._3)
|WHERE d._5 > 70
""".stripMargin)
val explainDFMultStr = dfMultStr.queryExecution.explainString(
ExplainMode
.fromString("extended"))
assert(explainDFMultStr.contains("dynamicpruningexpression"))
checkAnswer(dfMultStr, Seq(Row(8, 9, 10, 11), Row(9, 10, 11, 12)))
}
}
}
}

class ParquetV1QuerySuite extends ParquetQuerySuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest {
* temporary table named `tableName`, then call `f`. The temporary table together with the
* Parquet file will be dropped/deleted after `f` returns.
*/
protected def withParquetTable[T <: Product: ClassTag: TypeTag]
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = withDataSourceTable(data, tableName, testVectorized)(f)
/**
 * Writes `data` to a Parquet file (optionally partitioned by
 * `partitionNames`), registers it as temp view `tableName`, runs `f`, then
 * drops the view and deletes the file.
 */
protected def withParquetTable[T <: Product: ClassTag: TypeTag](
    data: Seq[T],
    tableName: String,
    testVectorized: Boolean = true,
    partitionNames: Seq[String] = Seq.empty)(f: => Unit): Unit = {
  // Delegate to the generic data-source helper (dataSourceName is parquet).
  withDataSourceTable(data, tableName, testVectorized, partitionNames)(f)
}

/** Writes `data` to a Parquet file at `path`; delegates to the generic data-source helper. */
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = makeDataSourceFile(data, path)
Expand Down